diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 90b3fa55..942a4273 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -19,6 +19,9 @@ jobs: uses: actions/checkout@v5 - name: Run unit tests run: go test -race -short + - name: Run connection state unit tests + run: go test -race -short + working-directory: connectionstate lint: runs-on: ubuntu-latest diff --git a/client_side_statement_test.go b/client_side_statement_test.go index f99609e6..f20067e2 100644 --- a/client_side_statement_test.go +++ b/client_side_statement_test.go @@ -25,12 +25,13 @@ import ( "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/googleapis/go-sql-spanner/connectionstate" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/structpb" ) func TestStatementExecutor_StartBatchDdl(t *testing.T) { - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() @@ -61,7 +62,7 @@ func TestStatementExecutor_StartBatchDdl(t *testing.T) { } func TestStatementExecutor_StartBatchDml(t *testing.T) { - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() @@ -98,7 +99,7 @@ func TestStatementExecutor_StartBatchDml(t *testing.T) { } func TestStatementExecutor_RetryAbortsInternally(t *testing.T) { - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() for i, test := range []struct { @@ -154,7 +155,7 @@ func TestStatementExecutor_RetryAbortsInternally(t *testing.T) { } func TestStatementExecutor_AutocommitDmlMode(t *testing.T) { - c := &conn{logger: noopLogger, connector: &connector{}} + c := &conn{logger: noopLogger, connector: &connector{}, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} _ = c.ResetSession(context.Background()) s := &statementExecutor{} ctx := context.Background() @@ -211,7 +212,7 @@ func TestStatementExecutor_AutocommitDmlMode(t *testing.T) { } func TestStatementExecutor_ReadOnlyStaleness(t *testing.T) { - c := &conn{logger: noopLogger} + c := &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() for i, test := range []struct { @@ -282,7 +283,7 @@ func TestStatementExecutor_ReadOnlyStaleness(t *testing.T) { func TestShowCommitTimestamp(t *testing.T) { t.Parallel() - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() @@ -328,7 +329,7 @@ func TestShowCommitTimestamp(t *testing.T) { } func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) { - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() for i, test := range []struct { @@ -384,7 +385,7 @@ func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) { } func TestStatementExecutor_MaxCommitDelay(t *testing.T) { - c := &conn{logger: noopLogger} + c := &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} ctx := context.Background() for i, test := range []struct { @@ -457,7 +458,7 @@ func TestStatementExecutor_SetTransactionTag(t *testing.T) { {"", "tag-with-missing-opening-quote'", true}, {"", "'tag-with-missing-closing-quote", true}, } { - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} it, err := s.ShowTransactionTag(ctx, c, "", ExecOptions{}, nil) @@ -517,7 +518,7 @@ func TestStatementExecutor_SetTransactionTag(t *testing.T) { func TestStatementExecutor_UsesExecOptions(t *testing.T) { ctx := context.Background() - c := &conn{retryAborts: true, logger: noopLogger} + c := &conn{retryAborts: true, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{})} s := &statementExecutor{} it, err := s.ShowTransactionTag(ctx, c, "", ExecOptions{DecodeOption: DecodeOptionProto, ReturnResultSetMetadata: true, ReturnResultSetStats: true}, nil) diff --git a/conn.go b/conn.go index addca502..929294bc 100644 --- a/conn.go +++ b/conn.go @@ -27,6 +27,7 @@ import ( adminapi "cloud.google.com/go/spanner/admin/database/apiv1" adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/connectionstate" "google.golang.org/api/iterator" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -231,6 +232,8 @@ type conn struct { execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, *spanner.CommitResponse, error) execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error) + // state contains the current ConnectionState for this connection. + state *connectionstate.ConnectionState // batch is the currently active DDL or DML batch on this connection. batch *batch // autoBatchDml determines whether DML statements should automatically @@ -244,11 +247,6 @@ type conn struct { // statements was correct. autoBatchDmlUpdateCountVerification bool - // autocommitDMLMode determines the type of DML to use when a single DML - // statement is executed on a connection. The default is Transactional, but - // it can also be set to PartitionedNonAtomic to execute the statement as - // Partitioned DML. - autocommitDMLMode AutocommitDMLMode // readOnlyStaleness is used for queries in autocommit mode and for read-only transactions. readOnlyStaleness spanner.TimestampBound // isolationLevel determines the default isolation level that is used for read/write @@ -308,7 +306,7 @@ func (c *conn) setRetryAbortsInternally(retry bool) (driver.Result, error) { } func (c *conn) AutocommitDMLMode() AutocommitDMLMode { - return c.autocommitDMLMode + return propertyAutocommitDmlMode.GetValueOrDefault(c.state) } func (c *conn) SetAutocommitDMLMode(mode AutocommitDMLMode) error { @@ -320,7 +318,9 @@ func (c *conn) SetAutocommitDMLMode(mode AutocommitDMLMode) error { } func (c *conn) setAutocommitDMLMode(mode AutocommitDMLMode) (driver.Result, error) { - c.autocommitDMLMode = mode + if err := propertyAutocommitDmlMode.SetValue(c.state, mode, connectionstate.ContextUser); err != nil { + return nil, err + } return driver.ResultNoRows, nil } @@ -689,8 +689,9 @@ func (c *conn) ResetSession(_ context.Context) error { c.retryAborts = c.connector.retryAbortsInternally c.isolationLevel = c.connector.connectorConfig.IsolationLevel c.beginTransactionOption = c.connector.connectorConfig.BeginTransactionOption + + _ = c.state.Reset(connectionstate.ContextUser) // TODO: Reset the following fields to the connector default - c.autocommitDMLMode = Transactional c.readOnlyStaleness = spanner.TimestampBound{} c.execOptions = ExecOptions{ DecodeToNativeArrays: c.connector.connectorConfig.DecodeToNativeArrays, @@ -887,7 +888,7 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOp c.batch.statements = append(c.batch.statements, ss) res = &result{} } else { - dmlMode := c.autocommitDMLMode + dmlMode := c.AutocommitDMLMode() if execOptions.AutocommitDMLMode != Unspecified { dmlMode = execOptions.AutocommitDMLMode } @@ -1015,6 +1016,13 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e c.resetForRetry = false return c.tx, nil } + // Also start a transaction on the ConnectionState if the BeginTx call was successful. + defer func() { + if c.tx != nil { + _ = c.state.Begin() + } + }() + readOnlyTxOpts := c.getReadOnlyTransactionOptions() batchReadOnlyTxOpts := c.getBatchReadOnlyTransactionOptions() readWriteTransactionOptions := c.getTransactionOptions() @@ -1072,13 +1080,18 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e roTx: ro, boTx: bo, logger: logger, - close: func() { + close: func(result txResult) { if batchReadOnlyTxOpts.close != nil { batchReadOnlyTxOpts.close() } if readOnlyTxOpts.close != nil { readOnlyTxOpts.close() } + if result == txResultCommit { + _ = c.state.Commit() + } else { + _ = c.state.Rollback() + } c.tx = nil }, } @@ -1095,7 +1108,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e conn: c, logger: logger, rwTx: tx, - close: func(commitResponse *spanner.CommitResponse, commitErr error) { + close: func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) { if readWriteTransactionOptions.close != nil { readWriteTransactionOptions.close() } @@ -1103,6 +1116,13 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e c.tx = nil if commitErr == nil { c.commitResponse = commitResponse + if result == txResultCommit { + _ = c.state.Commit() + } else { + _ = c.state.Rollback() + } + } else { + _ = c.state.Rollback() } }, // Disable internal retries if any of these options have been set. diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go index eaffb39d..8560acba 100644 --- a/conn_with_mockserver_test.go +++ b/conn_with_mockserver_test.go @@ -23,6 +23,7 @@ import ( "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/connectionstate" "github.com/googleapis/go-sql-spanner/testutil" ) @@ -448,3 +449,105 @@ func TestSetRetryAbortsInternallyInActiveTransaction(t *testing.T) { } _ = tx.Rollback() } + +func TestSetAutocommitDMLMode(t *testing.T) { + t.Parallel() + + for _, tp := range []connectionstate.Type{connectionstate.TypeTransactional, connectionstate.TypeNonTransactional} { + db, _, teardown := setupTestDBConnectionWithConnectorConfig(t, ConnectorConfig{ + Project: "p", + Instance: "i", + Database: "d", + ConnectionStateType: tp, + }) + defer teardown() + + conn, err := db.Conn(context.Background()) + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn.Close() }() + + _ = conn.Raw(func(driverConn interface{}) error { + c, _ := driverConn.(SpannerConn) + if g, w := c.AutocommitDMLMode(), Transactional; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + if err := c.SetAutocommitDMLMode(PartitionedNonAtomic); err != nil { + t.Fatal(err) + } + if g, w := c.AutocommitDMLMode(), PartitionedNonAtomic; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + return nil + }) + + // Set the value in a transaction and commit. + tx, err := conn.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + t.Fatal(err) + } + _ = conn.Raw(func(driverConn interface{}) error { + c, _ := driverConn.(SpannerConn) + // The value should be the same as before the transaction started. + if g, w := c.AutocommitDMLMode(), PartitionedNonAtomic; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + // Changes in a transaction should be visible in the transaction. + if err := c.SetAutocommitDMLMode(Transactional); err != nil { + t.Fatal(err) + } + if g, w := c.AutocommitDMLMode(), Transactional; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + return nil + }) + // Committing the transaction should make the change durable (and is a no-op if the connection state type is + // non-transactional). + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + _ = conn.Raw(func(driverConn interface{}) error { + c, _ := driverConn.(SpannerConn) + if g, w := c.AutocommitDMLMode(), Transactional; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + return nil + }) + + // Set the value in a transaction and rollback. + tx, err = conn.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + t.Fatal(err) + } + _ = conn.Raw(func(driverConn interface{}) error { + c, _ := driverConn.(SpannerConn) + if err := c.SetAutocommitDMLMode(PartitionedNonAtomic); err != nil { + t.Fatal(err) + } + if g, w := c.AutocommitDMLMode(), PartitionedNonAtomic; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + return nil + }) + // Rolling back the transaction will undo the change if the connection state is transactional. + // In case of non-transactional state, the rollback does not have an effect, as the state change was persisted + // directly when SetAutocommitDMLMode was called. + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + _ = conn.Raw(func(driverConn interface{}) error { + c, _ := driverConn.(SpannerConn) + var expected AutocommitDMLMode + if tp == connectionstate.TypeTransactional { + expected = Transactional + } else { + expected = PartitionedNonAtomic + } + if g, w := c.AutocommitDMLMode(), expected; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + return nil + }) + } +} diff --git a/connection_properties.go b/connection_properties.go new file mode 100644 index 00000000..bd38d042 --- /dev/null +++ b/connection_properties.go @@ -0,0 +1,53 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spannerdriver + +import "github.com/googleapis/go-sql-spanner/connectionstate" + +// connectionProperties contains all supported connection properties for Spanner. +// These properties are added to all connectionstate.ConnectionState instances that are created for Spanner connections. +var connectionProperties = map[string]connectionstate.ConnectionProperty{} + +// The following variables define the various connectionstate.ConnectionProperty instances that are supported and used +// by the Spanner database/sql driver. They are defined as global variables, so they can be used directly in the driver +// to get/set the state of exactly that property. + +var propertyConnectionStateType = createConnectionProperty( + "connection_state_type", + "The type of connection state to use for this connection. Can only be set at start up. "+ + "If no value is set, then the database dialect default will be used, "+ + "which is NON_TRANSACTIONAL for GoogleSQL and TRANSACTIONAL for PostgreSQL.", + connectionstate.TypeDefault, + []connectionstate.Type{connectionstate.TypeDefault, connectionstate.TypeTransactional, connectionstate.TypeNonTransactional}, + connectionstate.ContextStartup, +) +var propertyAutocommitDmlMode = createConnectionProperty( + "autocommit_dml_mode", + "Determines the transaction type that is used to execute DML statements when the connection is in auto-commit mode.", + Transactional, + []AutocommitDMLMode{Transactional, PartitionedNonAtomic}, + connectionstate.ContextUser, +) + +func createConnectionProperty[T comparable](name, description string, defaultValue T, validValues []T, context connectionstate.Context) *connectionstate.TypedConnectionProperty[T] { + prop := connectionstate.CreateConnectionProperty(name, description, defaultValue, validValues, context) + connectionProperties[prop.Key()] = prop + return prop +} + +func createInitialConnectionState(connectionStateType connectionstate.Type, initialValues map[string]connectionstate.ConnectionPropertyValue) *connectionstate.ConnectionState { + state, _ := connectionstate.NewConnectionState(connectionStateType, connectionProperties, initialValues) + return state +} diff --git a/connectionstate/connection_property.go b/connectionstate/connection_property.go new file mode 100644 index 00000000..af9d46f0 --- /dev/null +++ b/connectionstate/connection_property.go @@ -0,0 +1,344 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionstate + +import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// Context indicates when a ConnectionProperty may be set. +type Context int + +const ( + // ContextStartup is used for ConnectionProperty instances that may only be set at startup and may not be changed + // during the lifetime of a connection. + ContextStartup = iota + // ContextUser is used for ConnectionProperty instances that may be set both at startup and during the lifetime of + // a connection. + ContextUser +) + +func (c Context) String() string { + switch c { + case ContextStartup: + return "Startup" + case ContextUser: + return "User" + default: + return "Unknown" + } +} + +// ConnectionProperty defines the public interface for connection properties. +type ConnectionProperty interface { + // Key returns the unique key of the ConnectionProperty. This is equal to the name of the property for properties + // without an extension, and equal to `extension.name` for properties with an extension. + Key() string + // Context returns the Context where the property is allowed to be updated (e.g. only at startup or during the + // lifetime of a connection). + Context() Context + // CreateInitialValue creates an initial value of the property with the default value of the property as the current + // and reset value. + CreateInitialValue() ConnectionPropertyValue +} + +// CreateConnectionProperty is used to create a new ConnectionProperty with a specific type. This function is intended +// for use by driver implementations at initialization time to define the properties that the driver supports. +func CreateConnectionProperty[T comparable](name, description string, defaultValue T, validValues []T, context Context) *TypedConnectionProperty[T] { + return CreateConnectionPropertyWithExtension("", name, description, defaultValue, validValues, context) +} + +// CreateConnectionPropertyWithExtension is used to create a new ConnectionProperty with a specific type and an +// extension. Properties with an extension can be created dynamically during the lifetime of a connection. These are +// lost when the connection is reset to its original state. +func CreateConnectionPropertyWithExtension[T comparable](extension, name, description string, defaultValue T, validValues []T, context Context) *TypedConnectionProperty[T] { + var key string + if extension == "" { + key = name + } else { + key = extension + "." + name + } + return &TypedConnectionProperty[T]{ + key: key, + extension: extension, + name: name, + description: description, + defaultValue: defaultValue, + validValues: validValues, + context: context, + } +} + +var _ ConnectionProperty = (*TypedConnectionProperty[any])(nil) + +// TypedConnectionProperty implements the ConnectionProperty interface. +// All fields are unexported to ensure that the values can only be updated in accordance with the semantics of the +// chosen ConnectionState Type. +type TypedConnectionProperty[T comparable] struct { + key string + extension string + name string + description string + defaultValue T + validValues []T + context Context +} + +func (p *TypedConnectionProperty[T]) String() string { + return p.Key() +} + +func (p *TypedConnectionProperty[T]) Key() string { + return p.key +} + +func (p *TypedConnectionProperty[T]) Context() Context { + return p.context +} + +func (p *TypedConnectionProperty[T]) CreateInitialValue() ConnectionPropertyValue { + return &connectionPropertyValue[T]{ + connectionProperty: p, + resetValue: p.defaultValue, + value: p.defaultValue, + removeAtReset: false, + } +} + +// GetValueOrDefault returns the current value of the property in the given ConnectionState. +// It returns the default of the property if no value is found. +func (p *TypedConnectionProperty[T]) GetValueOrDefault(state *ConnectionState) T { + value, _ := state.value(p /*returnErrForUnknownProperty=*/, false) + if value == nil { + return p.defaultValue + } + if typedValue, ok := value.(*connectionPropertyValue[T]); ok { + return typedValue.value + } + return p.defaultValue +} + +// GetValueOrError returns the current value of the property in the given ConnectionState. +// It returns an error if no value is found. +func (p *TypedConnectionProperty[T]) GetValueOrError(state *ConnectionState) (T, error) { + value, err := state.value(p /*returnErrForUnknownProperty=*/, true) + if err != nil { + return p.zeroAndErr(err) + } + if value == nil { + return p.zeroAndErr(status.Errorf(codes.InvalidArgument, "no value found for property: %q", p)) + } + if typedValue, ok := value.(*connectionPropertyValue[T]); ok { + return typedValue.value, nil + } + return p.zeroAndErr(status.Errorf(codes.InvalidArgument, "value has wrong type: %s", value)) +} + +// ResetValue resets the value of the property in the given ConnectionState to its original value. +// +// The given Context should indicate the current context where the application tries to reset the value, e.g. it should +// be ContextUser if the reset happens during the lifetime of a connection, and ContextStartup if the reset happens at +// the creation of a connection. +func (p *TypedConnectionProperty[T]) ResetValue(state *ConnectionState, context Context) error { + value, _ := state.value(p /*returnErrForUnknownProperty=*/, false) + if value == nil { + var t T + return p.SetValue(state, t, context) + } else { + resetValue := value.GetResetValue() + typedResetValue, ok := resetValue.(T) + if !ok { + return status.Errorf(codes.InvalidArgument, "value has wrong type: %T", resetValue) + } + return p.SetValue(state, typedResetValue, context) + } +} + +// SetValue sets the value of the property in the given ConnectionState. +// +// The given Context should indicate the current context where the application tries to reset the value, e.g. it should +// be ContextUser if the reset happens during the lifetime of a connection, and ContextStartup if the reset happens at +// the creation of a connection. +func (p *TypedConnectionProperty[T]) SetValue(state *ConnectionState, value T, context Context) error { + if p.context < context { + return status.Errorf(codes.FailedPrecondition, "property has context %s and cannot be set in context %s", p.context, context) + } + if !state.inTransaction || state.connectionStateType == TypeNonTransactional || context < ContextUser { + // Set the value in non-transactional mode. + if err := p.setValue(state, state.properties, value, context); err != nil { + return err + } + // Remove the setting from the local settings if it's there, as the new setting is + // the one that should be used. + if state.localProperties != nil { + delete(state.localProperties, p.key) + } + return nil + } + // Set the value in a transaction. + if state.transactionProperties == nil { + state.transactionProperties = make(map[string]ConnectionPropertyValue) + } + if err := p.setValue(state, state.transactionProperties, value, context); err != nil { + return err + } + // Remove the setting from the local settings if it's there, as the new transaction setting is + // the one that should be used. + if state.localProperties != nil { + delete(state.localProperties, p.key) + } + return nil +} + +// SetLocalValue sets the local value of the property in the given ConnectionState. A local value is only visible +// for the remainder of the current transaction. The value is reset to the value it had before the transaction when the +// transaction ends, regardless whether the transaction committed or rolled back. +// +// Setting a local value outside a transaction is a no-op. +func (p *TypedConnectionProperty[T]) SetLocalValue(state *ConnectionState, value T) error { + if p.context < ContextUser { + return status.Error(codes.FailedPrecondition, "SetLocalValue is only supported for properties with context USER or higher") + } + if !state.inTransaction { + // SetLocalValue outside a transaction is a no-op. + return nil + } + if state.localProperties == nil { + state.localProperties = make(map[string]ConnectionPropertyValue) + } + return p.setValue(state, state.localProperties, value, ContextUser) +} + +func (p *TypedConnectionProperty[T]) setValue(state *ConnectionState, currentProperties map[string]ConnectionPropertyValue, value T, context Context) error { + if err := p.checkValidValue(value); err != nil { + return err + } + newValue, ok := currentProperties[p.key] + if !ok { + existingValue, ok := state.properties[p.key] + if !ok { + if p.extension == "" { + return unknownPropertyErr(p) + } + newValue = &connectionPropertyValue[T]{connectionProperty: p, removeAtReset: true} + } else { + newValue = existingValue.Copy() + } + } + if err := newValue.SetValue(value, context); err != nil { + return err + } + currentProperties[p.key] = newValue + return nil +} + +func (p *TypedConnectionProperty[T]) zeroAndErr(err error) (T, error) { + var t T + return t, err +} + +func (p *TypedConnectionProperty[T]) checkValidValue(value T) error { + if p.validValues == nil { + return nil + } + for _, validValue := range p.validValues { + if value == validValue { + return nil + } + } + return status.Errorf(codes.InvalidArgument, "value %v is not a valid value for %s", value, p) +} + +func unknownPropertyErr(p ConnectionProperty) error { + return status.Errorf(codes.InvalidArgument, "unrecognized configuration property %q", p.Key()) +} + +// ConnectionPropertyValue is the public interface for connection state property values. +type ConnectionPropertyValue interface { + // ConnectionProperty returns the property that this value is for. + ConnectionProperty() ConnectionProperty + // Copy creates a shallow copy of the ConnectionPropertyValue. + Copy() ConnectionPropertyValue + // SetValue sets the value of the property. The given value must be a valid value for the property. + SetValue(value any, context Context) error + // ResetValue resets the value of the property to the value it had at the creation of the connection. + ResetValue(context Context) error + // RemoveAtReset indicates whether the value should be removed from the ConnectionState when the ConnectionState is + // reset. This function should return true for property values that have been added to the set after the connection + // was created, for example because the user executed `set my_extension.my_property='some-value'`. + RemoveAtReset() bool + // GetResetValue returns the value that will be assigned to this property value if the value is reset. + GetResetValue() any +} + +// CreateInitialValue creates an initial value for a property. Both the current and the reset value are set to the given +// value. +func CreateInitialValue[T comparable](property *TypedConnectionProperty[T], value T) ConnectionPropertyValue { + return &connectionPropertyValue[T]{ + connectionProperty: property, + value: value, + resetValue: value, + removeAtReset: false, + } +} + +type connectionPropertyValue[T comparable] struct { + connectionProperty *TypedConnectionProperty[T] + resetValue T + value T + removeAtReset bool +} + +func (v *connectionPropertyValue[T]) ConnectionProperty() ConnectionProperty { + return v.connectionProperty +} + +func (v *connectionPropertyValue[T]) Copy() ConnectionPropertyValue { + return &connectionPropertyValue[T]{ + connectionProperty: v.connectionProperty, + resetValue: v.resetValue, + value: v.value, + removeAtReset: v.removeAtReset, + } +} + +func (v *connectionPropertyValue[T]) SetValue(value any, context Context) error { + if v.connectionProperty.context < context { + return status.Errorf(codes.FailedPrecondition, "property has context %s and cannot be set in context %s", v.connectionProperty.context, context) + } + typedValue, ok := value.(T) + if !ok { + return status.Errorf(codes.InvalidArgument, "value has wrong type: %T", value) + } + v.value = typedValue + return nil +} + +func (v *connectionPropertyValue[T]) ResetValue(context Context) error { + if v.connectionProperty.context < context { + return status.Errorf(codes.FailedPrecondition, "property has context %s and cannot be set in context %s", v.connectionProperty.context, context) + } + v.value = v.resetValue + return nil +} + +func (v *connectionPropertyValue[T]) RemoveAtReset() bool { + return v.removeAtReset +} + +func (v *connectionPropertyValue[T]) GetResetValue() any { + return v.resetValue +} diff --git a/connectionstate/connection_property_test.go b/connectionstate/connection_property_test.go new file mode 100644 index 00000000..232c4b46 --- /dev/null +++ b/connectionstate/connection_property_test.go @@ -0,0 +1,46 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionstate + +import ( + "testing" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestCheckValidValue(t *testing.T) { + type testValue int + const ( + testValueUnspecified testValue = iota + testValueTrue + testValueFalse + ) + + p := &TypedConnectionProperty[testValue]{validValues: []testValue{testValueTrue, testValueFalse}} + if err := p.checkValidValue(testValueTrue); err != nil { + t.Fatal(err) + } + if err := p.checkValidValue(testValueFalse); err != nil { + t.Fatal(err) + } + if err := p.checkValidValue(testValueUnspecified); err == nil { + t.Fatalf("expected error for %v", testValueUnspecified) + } else { + if g, w := status.Code(err), codes.InvalidArgument; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + } +} diff --git a/connectionstate/connection_state.go b/connectionstate/connection_state.go new file mode 100644 index 00000000..a09959ab --- /dev/null +++ b/connectionstate/connection_state.go @@ -0,0 +1,163 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionstate + +import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// Type represents the type of ConnectionState that is used. +// ConnectionState can be either transactional or non-transactional. +// Transactional ConnectionState means that any changes to the state +// during a transaction will only be persisted and visible after the +// transaction if the transaction is committed. +// Non-transactional ConnectionState means that changes that are made +// to the state during a transaction are persisted directly and will +// be visible after the transaction regardless whether the transaction +// was committed or rolled back. +type Type int + +const ( + // TypeDefault indicates that the default ConnectionState type of the database dialect should be used. + // GoogleSQL uses non-transactional ConnectionState by default. + // PostgreSQL uses transactional ConnectionState by default. + TypeDefault Type = iota + // TypeTransactional ConnectionState means that changes to the state during a transaction will only be + // persisted and visible after the transaction if the transaction is committed. + TypeTransactional + // TypeNonTransactional ConnectionState means that changes to the state during a transaction are persisted + // directly and remain visible after the transaction, regardless whether the transaction committed or not. + TypeNonTransactional +) + +// ConnectionState contains connection the state of a connection in a map. +type ConnectionState struct { + connectionStateType Type + inTransaction bool + properties map[string]ConnectionPropertyValue + transactionProperties map[string]ConnectionPropertyValue + localProperties map[string]ConnectionPropertyValue +} + +// NewConnectionState creates a new ConnectionState instance with the given initial values. +// The Type must be either TypeTransactional or TypeNonTransactional. +func NewConnectionState(connectionStateType Type, properties map[string]ConnectionProperty, initialValues map[string]ConnectionPropertyValue) (*ConnectionState, error) { + if connectionStateType == TypeDefault { + return nil, status.Error(codes.InvalidArgument, "connection state type cannot be TypeDefault") + } + state := &ConnectionState{ + connectionStateType: connectionStateType, + properties: make(map[string]ConnectionPropertyValue), + transactionProperties: nil, + localProperties: nil, + } + for key, value := range initialValues { + state.properties[key] = value.Copy() + } + for key, value := range properties { + if _, ok := state.properties[key]; !ok { + state.properties[key] = value.CreateInitialValue() + } + } + return state, nil +} + +// Begin starts a new transaction for this ConnectionState. +func (cs *ConnectionState) Begin() error { + if cs.inTransaction { + return status.Error(codes.FailedPrecondition, "connection state is already in transaction") + } + cs.inTransaction = true + return nil +} + +// Commit the current ConnectionState. +// This resets all local property values to the value they had before the transaction. +// If the Type is TypeTransactional, then all pending state changes are committed. +// If the Type is TypeNonTransactional, all state changes during the transaction have already been persisted during the +// transaction. +func (cs *ConnectionState) Commit() error { + if !cs.inTransaction { + return status.Error(codes.FailedPrecondition, "connection state is not in a transaction") + } + cs.inTransaction = false + if cs.transactionProperties != nil { + for key, value := range cs.transactionProperties { + cs.properties[key] = value + } + } + cs.transactionProperties = nil + cs.localProperties = nil + return nil +} + +// Rollback the current transactional state. +// This resets all local property values to the value they had before the transaction. +// If the Type is TypeTransactional, then all pending state changes are rolled back. +// If the Type is TypeNonTransactional, all state changes during the transaction have already been persisted during the +// transaction. +func (cs *ConnectionState) Rollback() error { + if !cs.inTransaction { + return status.Error(codes.FailedPrecondition, "connection state is not in a transaction") + } + cs.inTransaction = false + cs.transactionProperties = nil + cs.localProperties = nil + return nil +} + +// Reset the state to the initial values. Only the properties with a Context equal to or higher +// than the given Context will be reset. E.g. if the given Context is ContextUser, then properties +// with ContextStartup will not be reset. +func (cs *ConnectionState) Reset(context Context) error { + cs.transactionProperties = nil + cs.localProperties = nil + var remove map[string]bool + for _, value := range cs.properties { + if value.ConnectionProperty().Context() >= context { + if value.RemoveAtReset() { + if remove == nil { + remove = make(map[string]bool) + } + remove[value.ConnectionProperty().Key()] = true + } else { + if err := value.ResetValue(context); err != nil { + return err + } + } + } + } + for key := range remove { + delete(cs.properties, key) + } + return nil +} + +func (cs *ConnectionState) value(property ConnectionProperty, returnErrForUnknownProperty bool) (ConnectionPropertyValue, error) { + if val, ok := cs.localProperties[property.Key()]; ok { + return val, nil + } + if val, ok := cs.transactionProperties[property.Key()]; ok { + return val, nil + } + if val, ok := cs.properties[property.Key()]; ok { + return val, nil + } + if returnErrForUnknownProperty { + return nil, status.Errorf(codes.InvalidArgument, "unrecognized configuration property %q", property.Key()) + } + return nil, nil +} diff --git a/connectionstate/connection_state_test.go b/connectionstate/connection_state_test.go new file mode 100644 index 00000000..6a04369c --- /dev/null +++ b/connectionstate/connection_state_test.go @@ -0,0 +1,491 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionstate + +import ( + "testing" + + "cloud.google.com/go/spanner" + "google.golang.org/grpc/codes" +) + +func TestSetValueOutsideTransaction(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + for _, setToValue := range []string{"new-value", ""} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + if err := prop.SetValue(state, setToValue, ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), setToValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } + } +} + +func TestSetValueInTransactionAndCommit(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + for _, setToValue := range []string{"new-value", ""} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + if err := state.Begin(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + if err := prop.SetValue(state, setToValue, ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), setToValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the change is persisted if the transaction is committed. + if err := state.Commit(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), setToValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } + } +} + +func TestSetValueInTransactionAndRollback(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + for _, setToValue := range []string{"new-value", ""} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + if err := state.Begin(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + if err := prop.SetValue(state, setToValue, ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), setToValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the change is rolled back if the transaction is rolled back and the connection + // state is transactional. + if err := state.Rollback(); err != nil { + t.Fatal(err) + } + var expected string + if tp == TypeTransactional { + expected = prop.defaultValue + } else { + expected = setToValue + } + if g, w := prop.GetValueOrDefault(state), expected; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } + } +} + +func TestResetValueOutsideTransaction(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + _ = prop.SetValue(state, "new-value", ContextUser) + if err := prop.ResetValue(state, ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestResetValueInTransactionAndCommit(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + if err := state.Begin(); err != nil { + t.Fatal(err) + } + + // Change the value to something else than the default and commit. + _ = prop.SetValue(state, "new-value", ContextUser) + _ = state.Commit() + if g, w := prop.GetValueOrDefault(state), "new-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Now reset the value to the default in a new transaction. + if err := state.Begin(); err != nil { + t.Fatal(err) + } + if err := prop.ResetValue(state, ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the change is persisted if the transaction is committed. + if err := state.Commit(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestResetValueInTransactionAndRollback(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + if err := state.Begin(); err != nil { + t.Fatal(err) + } + + // Change the value to something else than the default and commit. + _ = prop.SetValue(state, "new-value", ContextUser) + if err := state.Commit(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), "new-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Now reset the value to the default in a new transaction. + if err := state.Begin(); err != nil { + t.Fatal(err) + } + if err := prop.ResetValue(state, ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the change is rolled back if the transaction is rolled back and the connection + // state is transactional. + if err := state.Rollback(); err != nil { + t.Fatal(err) + } + // The value should be rolled back to "new-value" if the state is transactional. + // The value should be "initial-value" if the state is non-transactional, as the Reset is persisted regardless + // whether the transaction committed or not. + var expected string + if tp == TypeTransactional { + expected = "new-value" + } else { + expected = prop.defaultValue + } + if g, w := prop.GetValueOrDefault(state), expected; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestSetLocalValue(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + if err := state.Begin(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + if err := prop.SetLocalValue(state, "new-value"); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), "new-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the change is no longer visible once the transaction has ended, even if the + // transaction was committed. + if err := state.Commit(); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestSetLocalValueOutsideTransaction(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + // Setting a local value outside a transaction is a no-op. + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + if err := prop.SetLocalValue(state, "new-value"); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestSetLocalValueForStartupProperty(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextStartup) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + err := prop.SetLocalValue(state, "new-value") + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestSetInTransactionForStartupProperty(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextStartup) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + _ = state.Begin() + err := prop.SetValue(state, "new-value", ContextUser) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestSetStartupProperty(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextStartup) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + err := prop.SetValue(state, "new-value", ContextUser) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestSetNormalAndLocalValue(t *testing.T) { + prop := CreateConnectionProperty("my_property", "Test property", "initial-value", nil, ContextUser) + properties := map[string]ConnectionProperty{ + "my_property": prop, + } + + for _, commit := range []bool{true, false} { + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + _ = state.Begin() + if g, w := prop.GetValueOrDefault(state), prop.defaultValue; g != w { + t.Fatalf("initial value mismatch\n Got: %v\nWant: %v", g, w) + } + // First set a local value. + if err := prop.SetLocalValue(state, "local-value"); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), "local-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + // Then set a 'standard' value in a transaction. + // This should override the local value. + if err := prop.SetValue(state, "new-value", ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), "new-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + // Then set yet another local value. This should take precedence within the current transaction. + if err := prop.SetLocalValue(state, "second-local-value"); err != nil { + t.Fatal(err) + } + if g, w := prop.GetValueOrDefault(state), "second-local-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the last local value is lost when the transaction ends. + // The value that was set in the transaction should be persisted if + // the transaction is committed, or if we are using non-transactional state. + if commit { + _ = state.Commit() + if g, w := prop.GetValueOrDefault(state), "new-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + _ = state.Rollback() + if tp == TypeNonTransactional { + // The transaction was rolled back, but this should have no impact on the + // SET statement in the transaction. So the new-value should be persisted. + if g, w := prop.GetValueOrDefault(state), "new-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + // The transaction was rolled back. The value should have been reset to + // the value it had before the transaction. + if g, w := prop.GetValueOrDefault(state), "initial-value"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } + } + } + } +} + +func TestSetUnknownProperty(t *testing.T) { + properties := map[string]ConnectionProperty{} + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + // Try to add a new property without an extension to the connection state. + // This should fail. + prop := CreateConnectionProperty("prop", "Test property", "initial-value", nil, ContextUser) + err := prop.SetValue(state, "new-value", ContextUser) + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + // Adding a new property with an extension to the connection state should work. + propWithExtension := CreateConnectionPropertyWithExtension("spanner", "prop3", "Test property 3", "initial-value-3", nil, ContextUser) + if err := propWithExtension.SetValue(state, "new-value", ContextUser); err != nil { + t.Fatal(err) + } + } +} + +func TestReset(t *testing.T) { + prop1 := CreateConnectionProperty("prop1", "Test property 1", "initial-value-1", nil, ContextUser) + prop2 := CreateConnectionProperty("prop2", "Test property 2", "initial-value-2", nil, ContextUser) + properties := map[string]ConnectionProperty{ + prop1.key: prop1, + prop2.key: prop2, + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, map[string]ConnectionPropertyValue{}) + _ = prop1.SetValue(state, "new-value-1", ContextUser) + _ = prop2.SetValue(state, "new-value-2", ContextUser) + // Add a new property to the connection state. This will be removed when the state is reset. + prop3 := CreateConnectionPropertyWithExtension("spanner", "prop3", "Test property 3", "initial-value-3", nil, ContextUser) + if err := prop3.SetValue(state, "new-value-3", ContextUser); err != nil { + t.Fatal(err) + } + + if g, w := prop1.GetValueOrDefault(state), "new-value-1"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := prop2.GetValueOrDefault(state), "new-value-2"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := prop3.GetValueOrDefault(state), "new-value-3"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := state.Reset(ContextUser); err != nil { + t.Fatal(err) + } + } +} + +func TestResetWithInitialValues(t *testing.T) { + prop1 := CreateConnectionProperty("prop1", "Test property 1", "default-value-1", nil, ContextUser) + prop2 := CreateConnectionProperty("prop2", "Test property 2", "default-value-2", nil, ContextUser) + properties := map[string]ConnectionProperty{ + prop1.key: prop1, + prop2.key: prop2, + } + initialValues := map[string]ConnectionPropertyValue{ + prop2.key: CreateInitialValue(prop2, "initial-value-2"), + } + + for _, tp := range []Type{TypeTransactional, TypeNonTransactional} { + state, _ := NewConnectionState(tp, properties, initialValues) + // Verify that prop1 has the default value of the property, and that prop2 has the initial value that was + // passed in when the ConnectionState was created. + if g, w := prop1.GetValueOrDefault(state), "default-value-1"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := prop2.GetValueOrDefault(state), "initial-value-2"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that updating the values works. + _ = prop1.SetValue(state, "new-value-1", ContextUser) + _ = prop2.SetValue(state, "new-value-2", ContextUser) + + if g, w := prop1.GetValueOrDefault(state), "new-value-1"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := prop2.GetValueOrDefault(state), "new-value-2"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + + // Resetting the values should bring them back to the original values. + if err := state.Reset(ContextUser); err != nil { + t.Fatal(err) + } + if g, w := prop1.GetValueOrDefault(state), "default-value-1"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := prop2.GetValueOrDefault(state), "initial-value-2"; g != w { + t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w) + } + } +} diff --git a/driver.go b/driver.go index 4a4eb6f0..6772a030 100644 --- a/driver.go +++ b/driver.go @@ -39,6 +39,7 @@ import ( "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/google/uuid" "github.com/googleapis/gax-go/v2" + "github.com/googleapis/go-sql-spanner/connectionstate" "google.golang.org/api/iterator" "google.golang.org/api/option" "google.golang.org/api/option/internaloption" @@ -293,6 +294,17 @@ type ConnectorConfig struct { // any of those do not yet exist. AutoConfigEmulator bool + // ConnectionStateType determines the behavior of changes to connection state + // during a transaction. + // connectionstate.TypeTransactional means that changes during a transaction + // are only persisted if the transaction is committed. If the transaction is + // rolled back, any changes to the connection state during the transaction + // will be lost. + // connectionstate.TypeNonTransactional means that changes to the connection + // state during a transaction are persisted directly, and are always visible + // after the transaction, regardless whether the transaction was committed or + // rolled back. + ConnectionStateType connectionstate.Type // Params contains key/value pairs for commonly used configuration parameters // for connections. The valid values are the same as the parameters that can // be added to a connection string. @@ -678,6 +690,11 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { connId := uuid.New().String() logger := c.logger.With("connId", connId) + connectionStateType := c.connectorConfig.ConnectionStateType + if connectionStateType == connectionstate.TypeDefault { + // TODO: Determine the default type of connection state based on the dialect + connectionStateType = connectionstate.TypeNonTransactional + } connection := &conn{ parser: c.parser, connector: c, @@ -688,6 +705,8 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { database: databaseName, retryAborts: c.retryAbortsInternally, + // TODO: Pass in initial values for the connection state + state: createInitialConnectionState(connectionStateType, map[string]connectionstate.ConnectionPropertyValue{}), autoBatchDml: c.connectorConfig.AutoBatchDml, autoBatchDmlUpdateCount: c.connectorConfig.AutoBatchDmlUpdateCount, autoBatchDmlUpdateCountVerification: !c.connectorConfig.DisableAutoBatchDmlUpdateCountVerification, diff --git a/driver_test.go b/driver_test.go index 31c0fcc9..e3c84dfa 100644 --- a/driver_test.go +++ b/driver_test.go @@ -31,6 +31,7 @@ import ( "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/googleapis/go-sql-spanner/connectionstate" "google.golang.org/api/option" "google.golang.org/grpc/codes" ) @@ -486,12 +487,13 @@ func TestConnection_Reset(t *testing.T) { connector: &connector{ connectorConfig: ConnectorConfig{}, }, + state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), readOnlyStaleness: spanner.ExactStaleness(time.Second), batch: &batch{tp: dml}, commitResponse: &spanner.CommitResponse{}, tx: &readOnlyTransaction{ logger: noopLogger, - close: func() { + close: func(_ txResult) { txClosed = true }, }, @@ -517,6 +519,7 @@ func TestConnection_Reset(t *testing.T) { func TestConnection_NoNestedTransactions(t *testing.T) { c := conn{ logger: noopLogger, + state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), tx: &readOnlyTransaction{}, } _, err := c.BeginTx(context.Background(), driver.TxOptions{}) @@ -618,10 +621,10 @@ func TestConn_NonDdlStatementsInDdlBatch(t *testing.T) { t.Fatal(err) } c := &conn{ - parser: parser, - logger: noopLogger, - autocommitDMLMode: Transactional, - batch: &batch{tp: ddl}, + parser: parser, + logger: noopLogger, + state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}), + batch: &batch{tp: ddl}, execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options ExecOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, @@ -749,9 +752,9 @@ func TestConn_GetCommitResponseAfterAutocommitDml(t *testing.T) { } want := &spanner.CommitResponse{CommitTs: time.Now()} c := &conn{ - parser: parser, - logger: noopLogger, - autocommitDMLMode: Transactional, + parser: parser, + logger: noopLogger, + state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}), execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options ExecOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, diff --git a/transaction.go b/transaction.go index 5fc2b3c3..b83061b8 100644 --- a/transaction.go +++ b/transaction.go @@ -91,13 +91,20 @@ func (ri *readOnlyRowIterator) ResultSetStats() *sppb.ResultSetStats { } } +type txResult int + +const ( + txResultCommit txResult = iota + txResultRollback +) + var _ contextTransaction = &readOnlyTransaction{} type readOnlyTransaction struct { roTx *spanner.ReadOnlyTransaction boTx *spanner.BatchReadOnlyTransaction logger *slog.Logger - close func() + close func(result txResult) } func (tx *readOnlyTransaction) Commit() error { @@ -109,7 +116,7 @@ func (tx *readOnlyTransaction) Commit() error { } else if tx.roTx != nil { tx.roTx.Close() } - tx.close() + tx.close(txResultCommit) return nil } @@ -120,7 +127,7 @@ func (tx *readOnlyTransaction) Rollback() error { if tx.roTx != nil { tx.roTx.Close() } - tx.close() + tx.close(txResultRollback) return nil } @@ -233,7 +240,7 @@ type readWriteTransaction struct { active bool // batch is any DML batch that is active for this transaction. batch *batch - close func(commitResponse *spanner.CommitResponse, commitErr error) + close func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) // retryAborts indicates whether this transaction will automatically retry // the transaction if it is aborted by Spanner. The default is true. retryAborts bool @@ -409,7 +416,7 @@ func (tx *readWriteTransaction) Commit() (err error) { if tx.rwTx != nil { if !tx.retryAborts { ts, err := tx.rwTx.CommitWithReturnResp(tx.ctx) - tx.close(&ts, err) + tx.close(txResultCommit, &ts, err) return err } @@ -421,7 +428,7 @@ func (tx *readWriteTransaction) Commit() (err error) { tx.rwTx.Rollback(context.Background()) } } - tx.close(&commitResponse, err) + tx.close(txResultCommit, &commitResponse, err) return err } @@ -439,7 +446,7 @@ func (tx *readWriteTransaction) rollback(ctx context.Context) error { if tx.rwTx != nil { tx.rwTx.Rollback(ctx) } - tx.close(nil, nil) + tx.close(txResultRollback, nil, nil) return nil }