Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 94 additions & 49 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,9 @@ type conn struct {
// tempExecOptions can be set by passing it in as an argument to ExecContext or QueryContext
// and are applied only to that statement.
tempExecOptions *ExecOptions
// tempTransactionOptions are temporarily set right before a read/write transaction is started.
tempTransactionOptions *ReadWriteTransactionOptions
// tempReadOnlyTransactionOptions are temporarily set right before a read-only
// transaction is started on a Spanner connection.
tempReadOnlyTransactionOptions *ReadOnlyTransactionOptions
// tempBatchReadOnlyTransactionOptions are temporarily set right before a
// batch read-only transaction is started on a Spanner connection.
tempBatchReadOnlyTransactionOptions *BatchReadOnlyTransactionOptions
// tempTransactionCloseFunc is set right before a transaction is started, and is set as the
// close function for that transaction.
tempTransactionCloseFunc func()
}

func (c *conn) UnderlyingClient() (*spanner.Client, error) {
Expand Down Expand Up @@ -1011,8 +1006,10 @@ func (c *conn) options(reset bool) *ExecOptions {
TransactionTag: c.TransactionTag(),
IsolationLevel: toProtoIsolationLevelOrDefault(c.IsolationLevel()),
ReadLockMode: c.ReadLockMode(),
CommitPriority: propertyCommitPriority.GetValueOrDefault(c.state),
CommitOptions: spanner.CommitOptions{
MaxCommitDelay: c.maxCommitDelayPointer(),
MaxCommitDelay: c.maxCommitDelayPointer(),
ReturnCommitStats: propertyReturnCommitStats.GetValueOrDefault(c.state),
},
},
PartitionedQueryOptions: PartitionedQueryOptions{},
Expand Down Expand Up @@ -1045,16 +1042,43 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo
}

func (c *conn) withTempTransactionOptions(options *ReadWriteTransactionOptions) {
c.tempTransactionOptions = options
if options == nil {
return
}
c.tempTransactionCloseFunc = options.close
// Start a transaction for the connection state, so we can set the transaction options
// as local options in the current transaction.
_ = c.state.Begin()
if options.DisableInternalRetries {
_ = propertyRetryAbortsInternally.SetLocalValue(c.state, !options.DisableInternalRetries)
}
if options.TransactionOptions.BeginTransactionOption != spanner.DefaultBeginTransaction {
_ = propertyBeginTransactionOption.SetLocalValue(c.state, options.TransactionOptions.BeginTransactionOption)
}
if options.TransactionOptions.CommitOptions.MaxCommitDelay != nil {
_ = propertyMaxCommitDelay.SetLocalValue(c.state, *options.TransactionOptions.CommitOptions.MaxCommitDelay)
}
if options.TransactionOptions.CommitOptions.ReturnCommitStats {
_ = propertyReturnCommitStats.SetLocalValue(c.state, options.TransactionOptions.CommitOptions.ReturnCommitStats)
}
if options.TransactionOptions.TransactionTag != "" {
_ = propertyTransactionTag.SetLocalValue(c.state, options.TransactionOptions.TransactionTag)
}
if options.TransactionOptions.ReadLockMode != spannerpb.TransactionOptions_ReadWrite_READ_LOCK_MODE_UNSPECIFIED {
_ = propertyReadLockMode.SetLocalValue(c.state, options.TransactionOptions.ReadLockMode)
}
if options.TransactionOptions.IsolationLevel != spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED {
_ = propertyIsolationLevel.SetLocalValue(c.state, toSqlIsolationLevelOrDefault(options.TransactionOptions.IsolationLevel))
}
if options.TransactionOptions.ExcludeTxnFromChangeStreams {
_ = propertyExcludeTxnFromChangeStreams.SetLocalValue(c.state, options.TransactionOptions.ExcludeTxnFromChangeStreams)
}
if options.TransactionOptions.CommitPriority != spannerpb.RequestOptions_PRIORITY_UNSPECIFIED {
_ = propertyCommitPriority.SetLocalValue(c.state, options.TransactionOptions.CommitPriority)
}
}

func (c *conn) getTransactionOptions(execOptions *ExecOptions) ReadWriteTransactionOptions {
if c.tempTransactionOptions != nil {
defer func() { c.tempTransactionOptions = nil }()
opts := *c.tempTransactionOptions
opts.TransactionOptions.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.TransactionOptions.BeginTransactionOption)
return opts
}
txOpts := ReadWriteTransactionOptions{
TransactionOptions: execOptions.TransactionOptions,
DisableInternalRetries: !c.RetryAbortsInternally(),
Expand All @@ -1075,28 +1099,39 @@ func (c *conn) getTransactionOptions(execOptions *ExecOptions) ReadWriteTransact
}

func (c *conn) withTempReadOnlyTransactionOptions(options *ReadOnlyTransactionOptions) {
c.tempReadOnlyTransactionOptions = options
if options == nil {
return
}
c.tempTransactionCloseFunc = options.close
// Start a transaction for the connection state, so we can set the transaction options
// as local options in the current transaction.
_ = c.state.Begin()
if options.BeginTransactionOption != spanner.DefaultBeginTransaction {
_ = propertyBeginTransactionOption.SetLocalValue(c.state, options.BeginTransactionOption)
}
if options.TimestampBound.String() != "(strong)" {
_ = propertyReadOnlyStaleness.SetLocalValue(c.state, options.TimestampBound)
}
}

func (c *conn) getReadOnlyTransactionOptions() ReadOnlyTransactionOptions {
if c.tempReadOnlyTransactionOptions != nil {
defer func() { c.tempReadOnlyTransactionOptions = nil }()
opts := *c.tempReadOnlyTransactionOptions
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.BeginTransactionOption)
return opts
}
return ReadOnlyTransactionOptions{TimestampBound: c.ReadOnlyStaleness(), BeginTransactionOption: c.convertDefaultBeginTransactionOption(propertyBeginTransactionOption.GetValueOrDefault(c.state))}
}

func (c *conn) withTempBatchReadOnlyTransactionOptions(options *BatchReadOnlyTransactionOptions) {
c.tempBatchReadOnlyTransactionOptions = options
if options == nil {
return
}
c.tempTransactionCloseFunc = options.close
// Start a transaction for the connection state, so we can set the transaction options
// as local options in the current transaction.
_ = c.state.Begin()
if options.TimestampBound.String() != "(strong)" {
_ = propertyReadOnlyStaleness.SetLocalValue(c.state, options.TimestampBound)
}
}

func (c *conn) getBatchReadOnlyTransactionOptions() BatchReadOnlyTransactionOptions {
if c.tempBatchReadOnlyTransactionOptions != nil {
defer func() { c.tempBatchReadOnlyTransactionOptions = nil }()
return *c.tempBatchReadOnlyTransactionOptions
}
return BatchReadOnlyTransactionOptions{TimestampBound: c.ReadOnlyStaleness()}
}

Expand All @@ -1108,7 +1143,6 @@ func (c *conn) BeginReadOnlyTransaction(ctx context.Context, options *ReadOnlyTr
c.withTempReadOnlyTransactionOptions(options)
tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true})
if err != nil {
c.withTempReadOnlyTransactionOptions(nil)
return nil, err
}
return tx, nil
Expand All @@ -1122,7 +1156,6 @@ func (c *conn) BeginReadWriteTransaction(ctx context.Context, options *ReadWrite
c.withTempTransactionOptions(options)
tx, err := c.BeginTx(ctx, driver.TxOptions{})
if err != nil {
c.withTempTransactionOptions(nil)
return nil, err
}
return tx, nil
Expand All @@ -1133,6 +1166,13 @@ func (c *conn) Begin() (driver.Tx, error) {
}

func (c *conn) BeginTx(ctx context.Context, driverOpts driver.TxOptions) (driver.Tx, error) {
defer func() {
c.tempTransactionCloseFunc = nil
}()
return c.beginTx(ctx, driverOpts, c.tempTransactionCloseFunc)
}

func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFunc func()) (driver.Tx, error) {
if c.resetForRetry {
c.resetForRetry = false
return c.tx, nil
Expand All @@ -1141,6 +1181,10 @@ func (c *conn) BeginTx(ctx context.Context, driverOpts driver.TxOptions) (driver
defer func() {
if c.tx != nil {
_ = c.state.Begin()
} else {
// Rollback in case the connection state transaction was started before this function
// was called, for example if the caller set temporary transaction options.
_ = c.state.Rollback()
}
}()

Expand Down Expand Up @@ -1180,6 +1224,9 @@ func (c *conn) BeginTx(ctx context.Context, driverOpts driver.TxOptions) (driver
if batchReadOnly && !driverOpts.ReadOnly {
return nil, status.Error(codes.InvalidArgument, "levelBatchReadOnly can only be used for read-only transactions")
}
if closeFunc == nil {
closeFunc = func() {}
}

if driverOpts.ReadOnly {
var logger *slog.Logger
Expand All @@ -1188,49 +1235,47 @@ func (c *conn) BeginTx(ctx context.Context, driverOpts driver.TxOptions) (driver
if batchReadOnly {
logger = c.logger.With("tx", "batchro")
var err error
// BatchReadOnly transactions (currently) do not support inline-begin.
// This means that the transaction options must be supplied here, and not through a callback.
bo, err = c.client.BatchReadOnlyTransaction(ctx, batchReadOnlyTxOpts.TimestampBound)
if err != nil {
return nil, err
}
ro = &bo.ReadOnlyTransaction
} else {
logger = c.logger.With("tx", "ro")
ro = c.client.ReadOnlyTransaction().WithBeginTransactionOption(readOnlyTxOpts.BeginTransactionOption).WithTimestampBound(readOnlyTxOpts.TimestampBound)
ro = c.client.ReadOnlyTransaction().WithBeginTransactionOption(readOnlyTxOpts.BeginTransactionOption)
}
c.tx = &readOnlyTransaction{
roTx: ro,
boTx: bo,
logger: logger,
close: func(result txResult) {
if batchReadOnlyTxOpts.close != nil {
batchReadOnlyTxOpts.close()
}
if readOnlyTxOpts.close != nil {
readOnlyTxOpts.close()
}
closeFunc()
if result == txResultCommit {
_ = c.state.Commit()
} else {
_ = c.state.Rollback()
}
c.tx = nil
},
timestampBoundCallback: func() spanner.TimestampBound {
return propertyReadOnlyStaleness.GetValueOrDefault(c.state)
},
}
return c.tx, nil
}

// These options are only used to determine how to start the transaction.
// All other options are fetched in a callback that is called when the transaction is actually started.
// That callback reads all transaction options from the connection state at that moment. This allows
// applications to execute a series of statement like this:
// BEGIN TRANSACTION;
// SET LOCAL transaction_tag='my_tag';
// SET LOCAL commit_priority=LOW;
// INSERT INTO my_table ... -- This starts the transaction with the options above included.
opts := spanner.TransactionOptions{}
if c.tempTransactionOptions != nil {
opts = c.tempTransactionOptions.TransactionOptions
}
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.BeginTransactionOption)
tempCloseFunc := func() {}
if c.tempTransactionOptions != nil && c.tempTransactionOptions.close != nil {
tempCloseFunc = c.tempTransactionOptions.close
}
if !disableRetryAborts && c.tempTransactionOptions != nil {
disableRetryAborts = c.tempTransactionOptions.DisableInternalRetries
}
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(propertyBeginTransactionOption.GetValueOrDefault(c.state))

tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, func() spanner.TransactionOptions {
defer func() {
Expand All @@ -1249,7 +1294,7 @@ func (c *conn) BeginTx(ctx context.Context, driverOpts driver.TxOptions) (driver
logger: logger,
rwTx: tx,
close: func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) {
tempCloseFunc()
closeFunc()
c.prevTx = c.tx
c.tx = nil
if commitErr == nil {
Expand Down
21 changes: 21 additions & 0 deletions connection_properties.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,27 @@ var propertyMaxCommitDelay = createConnectionProperty(
connectionstate.ContextUser,
connectionstate.ConvertDuration,
)
var propertyCommitPriority = createConnectionProperty(
"commit_priority",
"Sets the priority for commit RPC invocations from this connection (HIGH/MEDIUM/LOW/UNSPECIFIED). "+
"The default is UNSPECIFIED.",
spannerpb.RequestOptions_PRIORITY_UNSPECIFIED,
false,
nil,
connectionstate.ContextUser,
func(value string) (spannerpb.RequestOptions_Priority, error) {
return parseRpcPriority(value)
},
)
var propertyReturnCommitStats = createConnectionProperty(
"return_commit_stats",
"return_commit_stats determines whether transactions should request Spanner to return commit statistics.",
false,
false,
nil,
connectionstate.ContextUser,
connectionstate.ConvertBool,
)

// ------------------------------------------------------------------------------------------------
// Statement connection properties.
Expand Down
24 changes: 18 additions & 6 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1148,7 +1148,6 @@ func BeginReadWriteTransaction(ctx context.Context, db *sql.DB, options ReadWrit
}
tx, err := conn.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
clearTempReadWriteTransactionOptions(conn)
return nil, err
}
return tx, nil
Expand All @@ -1166,11 +1165,6 @@ func withTempReadWriteTransactionOptions(conn *sql.Conn, options *ReadWriteTrans
})
}

func clearTempReadWriteTransactionOptions(conn *sql.Conn) {
_ = withTempReadWriteTransactionOptions(conn, nil)
_ = conn.Close()
}

// ReadOnlyTransactionOptions can be used to create a read-only transaction
// on a Spanner connection.
type ReadOnlyTransactionOptions struct {
Expand Down Expand Up @@ -1529,6 +1523,24 @@ func toProtoIsolationLevelOrDefault(level sql.IsolationLevel) spannerpb.Transact
return res
}

func toSqlIsolationLevel(level spannerpb.TransactionOptions_IsolationLevel) (sql.IsolationLevel, error) {
switch level {
case spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED:
return sql.LevelDefault, nil
case spannerpb.TransactionOptions_SERIALIZABLE:
return sql.LevelSerializable, nil
case spannerpb.TransactionOptions_REPEATABLE_READ:
return sql.LevelRepeatableRead, nil
default:
}
return sql.LevelDefault, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid or unsupported isolation level: %v", level))
}

func toSqlIsolationLevelOrDefault(level spannerpb.TransactionOptions_IsolationLevel) sql.IsolationLevel {
res, _ := toSqlIsolationLevel(level)
return res
}

type spannerIsolationLevel sql.IsolationLevel

const (
Expand Down
2 changes: 1 addition & 1 deletion driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5076,7 +5076,7 @@ func TestBeginReadWriteTransaction(t *testing.T) {
t.Fatalf("missing transaction for ExecuteSqlRequest")
}
if req.Transaction.GetId() == nil {
t.Fatalf("missing begin selector for ExecuteSqlRequest")
t.Fatalf("missing ID selector for ExecuteSqlRequest")
}
if g, w := req.RequestOptions.TransactionTag, tag; g != w {
t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w)
Expand Down
Loading
Loading