diff --git a/README.md b/README.md index f5f07771..11bff8bc 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [Google Cloud Spanner](https://cloud.google.com/spanner) driver for Go's [database/sql](https://golang.org/pkg/database/sql/) package. -``` go +```go import _ "github.com/googleapis/go-sql-spanner" db, err := sql.Open("spanner", "projects/PROJECT/instances/INSTANCE/databases/DATABASE") @@ -65,6 +65,29 @@ the same named query parameter is used in multiple places in the statement. db.ExecContext(ctx, "DELETE FROM tweets WHERE id = @id", 14544498215374) ``` +### Query Options +Query options can be passed in as arguments to a query. Pass in a value of +type `spannerdriver.ExecOptions` to supply additional execution options for +a statement. The `spanner.QueryOptions` will be passed through to the Spanner +client as the query options to use for the query or DML statement. + +```go +tx.ExecContext(ctx, "INSERT INTO Singers (SingerId, Name) VALUES (@id, @name)", + spannerdriver.ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "insert_singer"}}, + 123, "Bruce Allison") +tx.QueryContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = ?", + spannerdriver.ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "select_singer"}}, + 123) +``` + +Statement tags (request tags) can also be set using the custom SQL statement +`set statement_tag='my_tag'`: + +```go +tx.ExecContext(ctx, "set statement_tag = 'select_singer'") +tx.QueryContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = ?", 123) +``` + ## Transactions - Read-write transactions always uses the strongest isolation level and ignore the user-specified level. @@ -72,9 +95,15 @@ db.ExecContext(ctx, "DELETE FROM tweets WHERE id = @id", 14544498215374) either Commit or Rollback. Calling either of these methods will end the current read-only transaction and return the session that is used to the session pool. -``` go +```go tx, err := db.BeginTx(ctx, &sql.TxOptions{}) // Read-write transaction. +// Read-write transaction with a transaction tag. +conn, _ := db.Conn(ctx) +_, _ := conn.ExecContext(ctx, "SET TRANSACTION_TAG='my_transaction_tag'") +tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) + + tx, err := db.BeginTx(ctx, &sql.TxOptions{ ReadOnly: true, // Read-only transaction using strong reads. }) @@ -91,8 +120,8 @@ tx, err := conn.BeginTx(ctx, &sql.TxOptions{ Spanner can abort a read/write transaction if concurrent modifications are detected that would violate the transaction consistency. When this happens, the driver will return the `ErrAbortedDueToConcurrentModification` error. You can use the -`RunTransaction` function to let the driver automatically retry transactions that -are aborted by Spanner. +`RunTransaction` and `RunTransactionWithOptions` functions to let the driver +automatically retry transactions that are aborted by Spanner. ```go package sample @@ -106,14 +135,14 @@ import ( spannerdriver "github.com/googleapis/go-sql-spanner" ) -spannerdriver.RunTransaction(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { +spannerdriver.RunTransactionWithOptions(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { row := tx.QueryRowContext(ctx, "select Name from Singers where SingerId=@id", 123) var name string if err := row.Scan(&name); err != nil { return err } return nil -}) +}, spanner.TransactionOptions{TransactionTag: "my_transaction_tag"}) ``` See also the [transaction runner sample](./examples/run-transaction/main.go). diff --git a/checksum_row_iterator.go b/checksum_row_iterator.go index 8690df0d..fe0aaa5c 100644 --- a/checksum_row_iterator.go +++ b/checksum_row_iterator.go @@ -49,9 +49,10 @@ type checksumRowIterator struct { *spanner.RowIterator metadata *sppb.ResultSetMetadata - ctx context.Context - tx *readWriteTransaction - stmt spanner.Statement + ctx context.Context + tx *readWriteTransaction + stmt spanner.Statement + options spanner.QueryOptions // nc (nextCount) indicates the number of times that next has been called // on the iterator. Next() will be called the same number of times during // a retry. @@ -160,7 +161,7 @@ func createMetadataChecksum(enc *gob.Encoder, buffer *bytes.Buffer, metadata *sp func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteStmtBasedTransaction) error { buffer := &bytes.Buffer{} enc := gob.NewEncoder(buffer) - retryIt := tx.Query(ctx, it.stmt) + retryIt := tx.QueryWithOptions(ctx, it.stmt, it.options) // If the original iterator had been stopped, we should also always stop the // new iterator. if it.stopped { diff --git a/client_side_statement.go b/client_side_statement.go index c8134858..693d0087 100644 --- a/client_side_statement.go +++ b/client_side_statement.go @@ -93,6 +93,22 @@ func (s *statementExecutor) ShowExcludeTxnFromChangeStreams(_ context.Context, c return &rows{it: it}, nil } +func (s *statementExecutor) ShowTransactionTag(_ context.Context, c *conn, _ string, _ []driver.NamedValue) (driver.Rows, error) { + it, err := createStringIterator("TransactionTag", c.TransactionTag()) + if err != nil { + return nil, err + } + return &rows{it: it}, nil +} + +func (s *statementExecutor) ShowStatementTag(_ context.Context, c *conn, _ string, _ []driver.NamedValue) (driver.Rows, error) { + it, err := createStringIterator("StatementTag", c.StatementTag()) + if err != nil { + return nil, err + } + return &rows{it: it}, nil +} + func (s *statementExecutor) StartBatchDdl(_ context.Context, c *conn, _ string, _ []driver.NamedValue) (driver.Result, error) { return c.startBatchDDL() } @@ -147,6 +163,35 @@ func (s *statementExecutor) SetExcludeTxnFromChangeStreams(_ context.Context, c return c.setExcludeTxnFromChangeStreams(exclude) } +func (s *statementExecutor) SetTransactionTag(_ context.Context, c *conn, params string, _ []driver.NamedValue) (driver.Result, error) { + tag, err := parseTag(params) + if err != nil { + return nil, err + } + return c.setTransactionTag(tag) +} + +func (s *statementExecutor) SetStatementTag(_ context.Context, c *conn, params string, _ []driver.NamedValue) (driver.Result, error) { + tag, err := parseTag(params) + if err != nil { + return nil, err + } + return c.setStatementTag(tag) +} + +func parseTag(params string) (string, error) { + if params == "" { + return "", spanner.ToSpannerError(status.Error(codes.InvalidArgument, "no value given for tag")) + } + tag := strings.TrimSpace(params) + if !(strings.HasPrefix(tag, "'") && strings.HasSuffix(tag, "'")) { + return "", spanner.ToSpannerError(status.Error(codes.InvalidArgument, "missing single quotes around tag")) + } + tag = strings.TrimLeft(tag, "'") + tag = strings.TrimRight(tag, "'") + return tag, nil +} + var strongRegexp = regexp.MustCompile("(?i)'STRONG'") var exactStalenessRegexp = regexp.MustCompile(`(?i)'(?PEXACT_STALENESS)[\t ]+(?P(\d{1,19})(s|ms|us|ns))'`) var maxStalenessRegexp = regexp.MustCompile(`(?i)'(?PMAX_STALENESS)[\t ]+(?P(\d{1,19})(s|ms|us|ns))'`) diff --git a/client_side_statement_test.go b/client_side_statement_test.go index 29085d75..4c0e5d5a 100644 --- a/client_side_statement_test.go +++ b/client_side_statement_test.go @@ -374,3 +374,75 @@ func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) { } } } + +func TestStatementExecutor_SetTransactionTag(t *testing.T) { + ctx := context.Background() + for i, test := range []struct { + wantValue string + setValue string + wantSetErr bool + }{ + {"test-tag", "'test-tag'", false}, + {"other-tag", " 'other-tag'\t\n", false}, + {" tag with spaces ", "' tag with spaces '", false}, + {"", "tag-without-quotes", true}, + {"", "tag-with-missing-opening-quote'", true}, + {"", "'tag-with-missing-closing-quote", true}, + } { + c := &conn{retryAborts: true} + s := &statementExecutor{} + + it, err := s.ShowTransactionTag(ctx, c, "", nil) + if err != nil { + t.Fatalf("%d: could not get current transaction tag value from connection: %v", i, err) + } + cols := it.Columns() + wantCols := []string{"TransactionTag"} + if !cmp.Equal(cols, wantCols) { + t.Fatalf("%d: column names mismatch\nGot: %v\nWant: %v", i, cols, wantCols) + } + values := make([]driver.Value, len(cols)) + if err := it.Next(values); err != nil { + t.Fatalf("%d: failed to get first row: %v", i, err) + } + wantValues := []driver.Value{""} + if !cmp.Equal(values, wantValues) { + t.Fatalf("%d: default transaction tag mismatch\nGot: %v\nWant: %v", i, values, wantValues) + } + if err := it.Next(values); err != io.EOF { + t.Fatalf("%d: error mismatch\nGot: %v\nWant: %v", i, err, io.EOF) + } + + // Set a transaction tag. + res, err := s.SetTransactionTag(ctx, c, test.setValue, nil) + if test.wantSetErr { + if err == nil { + t.Fatalf("%d: missing expected error for value %q", i, test.setValue) + } + } else { + if err != nil { + t.Fatalf("%d: could not set new value %q for exclude: %v", i, test.setValue, err) + } + if res != driver.ResultNoRows { + t.Fatalf("%d: result mismatch\nGot: %v\nWant: %v", i, res, driver.ResultNoRows) + } + } + + // Get the tag that was set + it, err = s.ShowTransactionTag(ctx, c, "", nil) + if err != nil { + t.Fatalf("%d: could not get current transaction tag value from connection: %v", i, err) + } + if err := it.Next(values); err != nil { + t.Fatalf("%d: failed to get first row: %v", i, err) + } + wantValues = []driver.Value{test.wantValue} + if !cmp.Equal(values, wantValues) { + t.Fatalf("%d: transaction tag mismatch\nGot: %v\nWant: %v", i, values, wantValues) + } + if err := it.Next(values); err != io.EOF { + t.Fatalf("%d: error mismatch\nGot: %v\nWant: %v", i, err, io.EOF) + } + + } +} diff --git a/client_side_statements_json.go b/client_side_statements_json.go index ad848391..7c047888 100644 --- a/client_side_statements_json.go +++ b/client_side_statements_json.go @@ -60,6 +60,24 @@ var jsonFile = `{ "method": "statementShowExcludeTxnFromChangeStreams", "exampleStatements": ["show variable exclude_txn_from_change_streams"] }, + { + "name": "SHOW VARIABLE TRANSACTION_TAG", + "executorName": "ClientSideStatementNoParamExecutor", + "resultType": "RESULT_SET", + "statementType": "SHOW_TRANSACTION_TAG", + "regex": "(?is)\\A\\s*show\\s+variable\\s+transaction_tag\\s*\\z", + "method": "statementShowTransactionTag", + "exampleStatements": ["show variable transaction_tag"] + }, + { + "name": "SHOW VARIABLE STATEMENT_TAG", + "executorName": "ClientSideStatementNoParamExecutor", + "resultType": "RESULT_SET", + "statementType": "SHOW_STATEMENT_TAG", + "regex": "(?is)\\A\\s*show\\s+variable\\s+statement_tag\\s*\\z", + "method": "statementShowStatementTag", + "exampleStatements": ["show variable statement_tag"] + }, { "name": "START BATCH DDL", "executorName": "ClientSideStatementNoParamExecutor", @@ -165,6 +183,36 @@ var jsonFile = `{ "allowedValues": "(TRUE|FALSE)", "converterName": "ClientSideStatementValueConverters$BooleanConverter" } + }, + { + "name": "SET TRANSACTION_TAG = ''", + "executorName": "ClientSideStatementSetExecutor", + "resultType": "NO_RESULT", + "statementType": "SET_TRANSACTION_TAG", + "regex": "(?is)\\A\\s*set\\s+transaction_tag\\s*(?:=)\\s*(.*)\\z", + "method": "statementSetTransactionTag", + "exampleStatements": ["set transaction_tag='tag1'", "set transaction_tag='tag2'", "set transaction_tag=''", "set transaction_tag='test_tag'"], + "setStatement": { + "propertyName": "TRANSACTION_TAG", + "separator": "=", + "allowedValues": "'(.*)'", + "converterName": "ClientSideStatementValueConverters$StringValueConverter" + } + }, + { + "name": "SET STATEMENT_TAG = ''", + "executorName": "ClientSideStatementSetExecutor", + "resultType": "NO_RESULT", + "statementType": "SET_STATEMENT_TAG", + "regex": "(?is)\\A\\s*set\\s+statement_tag\\s*(?:=)\\s*(.*)\\z", + "method": "statementSetStatementTag", + "exampleStatements": ["set statement_tag='tag1'", "set statement_tag='tag2'", "set statement_tag=''", "set statement_tag='test_tag'"], + "setStatement": { + "propertyName": "STATEMENT_TAG", + "separator": "=", + "allowedValues": "'(.*)'", + "converterName": "ClientSideStatementValueConverters$StringValueConverter" + } } ] } diff --git a/driver.go b/driver.go index 413a3c85..69460da1 100644 --- a/driver.go +++ b/driver.go @@ -85,6 +85,12 @@ func init() { type ExecOptions struct { // DecodeOption indicates how the returned rows should be decoded. DecodeOption DecodeOption + + // TransactionOptions are the transaction options that will be used for + // the transaction that is started by the statement. + TransactionOptions spanner.TransactionOptions + // QueryOptions are the query options that will be used for the statement. + QueryOptions spanner.QueryOptions } type DecodeOption int @@ -453,6 +459,32 @@ func (c *connector) closeClients() (err error) { // // This function will never return ErrAbortedDueToConcurrentModification. func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error) error { + return runTransactionWithOptions(ctx, db, opts, f, spanner.TransactionOptions{}) +} + +// RunTransactionWithOptions runs the given function in a transaction on the given database. +// If the connection is a connection to a Spanner database, the transaction will +// automatically be retried if the transaction is aborted by Spanner. Any other +// errors will be propagated to the caller and the transaction will be rolled +// back. The transaction will be committed if the supplied function did not +// return an error. +// +// If the connection is to a non-Spanner database, no retries will be attempted, +// and any error that occurs during the transaction will be propagated to the +// caller. +// +// The application should *NOT* call tx.Commit() or tx.Rollback(). This is done +// automatically by this function, depending on whether the transaction function +// returned an error or not. +// +// The given spanner.TransactionOptions will be used for the transaction. +// +// This function will never return ErrAbortedDueToConcurrentModification. +func RunTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error, spannerOptions spanner.TransactionOptions) error { + return runTransactionWithOptions(ctx, db, opts, f, spannerOptions) +} + +func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error, spannerOptions spanner.TransactionOptions) error { // Get a connection from the pool that we can use to run a transaction. // Getting a connection here already makes sure that we can reserve this // connection exclusively for the duration of this method. That again @@ -478,6 +510,7 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func // It is not a Spanner connection, so just ignore and continue without any special handling. return nil } + spannerConn.withTransactionOptions(spannerOptions) origRetryAborts = spannerConn.RetryAbortsInternally() return spannerConn.SetRetryAbortsInternally(false) }); err != nil { @@ -614,6 +647,15 @@ type SpannerConn interface { // mode and for read-only transaction. SetReadOnlyStaleness(staleness spanner.TimestampBound) error + // TransactionTag returns the transaction tag that will be applied to the next + // read/write transaction on this connection. The transaction tag that is set + // on the connection is cleared when a read/write transaction is started. + TransactionTag() string + // SetTransactionTag sets the transaction tag that should be applied to the + // next read/write transaction on this connection. The tag is cleared when a + // read/write transaction is started. + SetTransactionTag(transactionTag string) error + // ExcludeTxnFromChangeStreams returns true if the next transaction should be excluded from change streams with the // DDL option `allow_txn_exclusion=true`. ExcludeTxnFromChangeStreams() bool @@ -650,8 +692,14 @@ type SpannerConn interface { // has not been aborted is not supported and will cause an error to be // returned. resetTransactionForRetry(ctx context.Context, errDuringCommit bool) error + + // setTransactionOptions sets the TransactionOptions that should be used + // for this transaction. + withTransactionOptions(options spanner.TransactionOptions) } +var _ SpannerConn = &conn{} + type conn struct { connector *connector closed bool @@ -664,10 +712,10 @@ type conn struct { database string retryAborts bool - execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound) *spanner.RowIterator - execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (rowIterator, time.Time, error) - execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (int64, time.Time, error) - execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) + execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound, options ExecOptions) *spanner.RowIterator + execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (rowIterator, time.Time, error) + execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, time.Time, error) + execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error) // batch is the currently active DDL or DML batch on this connection. batch *batch @@ -679,11 +727,9 @@ type conn struct { autocommitDMLMode AutocommitDMLMode // readOnlyStaleness is used for queries in autocommit mode and for read-only transactions. readOnlyStaleness spanner.TimestampBound - // excludeTxnFromChangeStreams is used to exclude the next transaction from change streams with the DDL option - // `allow_txn_exclusion=true` - excludeTxnFromChangeStreams bool + // execOptions are applied to the next statement that is executed on this connection. - // It can only be set by passing it in as an argument to ExecContext or QueryContext + // It can be set by passing it in as an argument to ExecContext or QueryContext // and is cleared after each execution. execOptions ExecOptions } @@ -698,6 +744,7 @@ const ( type batch struct { tp batchType statements []spanner.Statement + options ExecOptions } // AutocommitDMLMode indicates whether a single DML statement should be executed @@ -777,7 +824,7 @@ func (c *conn) setReadOnlyStaleness(staleness spanner.TimestampBound) (driver.Re } func (c *conn) ExcludeTxnFromChangeStreams() bool { - return c.excludeTxnFromChangeStreams + return c.execOptions.TransactionOptions.ExcludeTxnFromChangeStreams } func (c *conn) SetExcludeTxnFromChangeStreams(excludeTxnFromChangeStreams bool) error { @@ -789,7 +836,38 @@ func (c *conn) setExcludeTxnFromChangeStreams(excludeTxnFromChangeStreams bool) if c.inTransaction() { return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "cannot set ExcludeTxnFromChangeStreams while a transaction is active")) } - c.excludeTxnFromChangeStreams = excludeTxnFromChangeStreams + c.execOptions.TransactionOptions.ExcludeTxnFromChangeStreams = excludeTxnFromChangeStreams + return driver.ResultNoRows, nil +} + +func (c *conn) TransactionTag() string { + return c.execOptions.TransactionOptions.TransactionTag +} + +func (c *conn) SetTransactionTag(transactionTag string) error { + _, err := c.setTransactionTag(transactionTag) + return err +} + +func (c *conn) setTransactionTag(transactionTag string) (driver.Result, error) { + if c.inTransaction() { + return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "cannot set transaction tag while a transaction is active")) + } + c.execOptions.TransactionOptions.TransactionTag = transactionTag + return driver.ResultNoRows, nil +} + +func (c *conn) StatementTag() string { + return c.execOptions.QueryOptions.RequestTag +} + +func (c *conn) SetStatementTag(statementTag string) error { + _, err := c.setStatementTag(statementTag) + return err +} + +func (c *conn) setStatementTag(statementTag string) (driver.Result, error) { + c.execOptions.QueryOptions.RequestTag = statementTag return driver.ResultNoRows, nil } @@ -844,8 +922,10 @@ func (c *conn) startBatchDDL() (driver.Result, error) { } func (c *conn) startBatchDML() (driver.Result, error) { + execOptions := c.options() + if c.inTransaction() { - return c.tx.StartBatchDML() + return c.tx.StartBatchDML(execOptions.QueryOptions) } if c.batch != nil { @@ -854,7 +934,7 @@ func (c *conn) startBatchDML() (driver.Result, error) { if c.inReadOnlyTransaction() { return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "This connection has an active read-only transaction. Read-only transactions cannot execute DML batches.")) } - c.batch = &batch{tp: dml} + c.batch = &batch{tp: dml, options: execOptions} return driver.ResultNoRows, nil } @@ -884,8 +964,9 @@ func (c *conn) runDDLBatch(ctx context.Context) (driver.Result, error) { func (c *conn) runDMLBatch(ctx context.Context) (driver.Result, error) { statements := c.batch.statements + options := c.batch.options c.batch = nil - return c.execBatchDML(ctx, statements) + return c.execBatchDML(ctx, statements, options) } func (c *conn) abortBatch() (driver.Result, error) { @@ -925,7 +1006,7 @@ func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (dr return driver.ResultNoRows, nil } -func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement) (driver.Result, error) { +func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement, options ExecOptions) (driver.Result, error) { if len(statements) == 0 { return &result{}, nil } @@ -937,12 +1018,12 @@ func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement) if !ok { return nil, status.Errorf(codes.FailedPrecondition, "connection is in a transaction that is not a read/write transaction") } - affected, err = tx.rwTx.BatchUpdate(ctx, statements) + affected, err = tx.rwTx.BatchUpdateWithOptions(ctx, statements, options.QueryOptions) } else { _, err = c.client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, transaction *spanner.ReadWriteTransaction) error { - affected, err = transaction.BatchUpdate(ctx, statements) + affected, err = transaction.BatchUpdateWithOptions(ctx, statements, options.QueryOptions) return err - }, c.createTransactionOptions()) + }, options.TransactionOptions) } return &result{rowsAffected: sum(affected)}, err } @@ -1142,16 +1223,15 @@ func (c *conn) Prepare(query string) (driver.Stmt, error) { } func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + execOptions := c.options() parsedSQL, args, err := parseParameters(query) if err != nil { return nil, err } - return &stmt{conn: c, query: parsedSQL, numArgs: len(args)}, nil + return &stmt{conn: c, query: parsedSQL, numArgs: len(args), execOptions: execOptions}, nil } func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - execOptions := c.options() - // Execute client side statement if it is one. clientStmt, err := parseClientSideStatement(c, query) if err != nil { @@ -1160,6 +1240,8 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam if clientStmt != nil { return clientStmt.QueryContext(ctx, args) } + + execOptions := c.options() // Clear the commit timestamp of this connection before we execute the query. c.commitTs = nil @@ -1173,7 +1255,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam if statementType == statementTypeDml { // Use a read/write transaction to execute the statement. var commitTs time.Time - iter, commitTs, err = c.execSingleQueryTransactional(ctx, c.client, stmt, c.createTransactionOptions()) + iter, commitTs, err = c.execSingleQueryTransactional(ctx, c.client, stmt, execOptions) if err != nil { return nil, err } @@ -1182,18 +1264,15 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam // The statement was either detected as being a query, or potentially not recognized at all. // In that case, just default to using a single-use read-only transaction and let Spanner // return an error if the statement is not suited for that type of transaction. - iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.readOnlyStaleness)} + iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.readOnlyStaleness, execOptions)} } } else { - iter = c.tx.Query(ctx, stmt) + iter = c.tx.Query(ctx, stmt, execOptions.QueryOptions) } return &rows{it: iter, decodeOption: execOptions.DecodeOption}, nil } func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - // Make sure options are reset after calling this method. - _ = c.options() - // Execute client side statement if it is one. stmt, err := parseClientSideStatement(c, query) if err != nil { @@ -1202,6 +1281,8 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name if stmt != nil { return stmt.ExecContext(ctx, args) } + execOptions := c.options() + // Clear the commit timestamp of this connection before we execute the statement. c.commitTs = nil @@ -1229,18 +1310,18 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name c.batch.statements = append(c.batch.statements, ss) } else { if c.autocommitDMLMode == Transactional { - rowsAffected, commitTs, err = c.execSingleDMLTransactional(ctx, c.client, ss, c.createTransactionOptions()) + rowsAffected, commitTs, err = c.execSingleDMLTransactional(ctx, c.client, ss, execOptions) if err == nil { c.commitTs = &commitTs } } else if c.autocommitDMLMode == PartitionedNonAtomic { - rowsAffected, err = c.execSingleDMLPartitioned(ctx, c.client, ss, c.createPartitionedDmlQueryOptions()) + rowsAffected, err = c.execSingleDMLPartitioned(ctx, c.client, ss, execOptions) } else { return nil, status.Errorf(codes.FailedPrecondition, "connection in invalid state for DML statements: %s", c.autocommitDMLMode.String()) } } } else { - rowsAffected, err = c.tx.ExecContext(ctx, ss) + rowsAffected, err = c.tx.ExecContext(ctx, ss, execOptions.QueryOptions) } if err != nil { return nil, err @@ -1275,6 +1356,10 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo return c.tx.resetForRetry(ctx) } +func (c *conn) withTransactionOptions(options spanner.TransactionOptions) { + c.execOptions.TransactionOptions = options +} + type spannerIsolationLevel sql.IsolationLevel const ( @@ -1304,6 +1389,8 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e if c.inBatch() { return nil, status.Error(codes.FailedPrecondition, "This connection has an active batch. Run or abort the batch before starting a new transaction.") } + + execOptions := c.options() disableRetryAborts := false sil := opts.Isolation >> 8 opts.Isolation = opts.Isolation - sil @@ -1325,8 +1412,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e return c.tx, nil } - options := c.createTransactionOptions() - tx, err := spanner.NewReadWriteStmtBasedTransactionWithOptions(ctx, c.client, options) + tx, err := spanner.NewReadWriteStmtBasedTransactionWithOptions(ctx, c.client, execOptions.TransactionOptions) if err != nil { return nil, err } @@ -1367,8 +1453,8 @@ func (c *conn) inReadWriteTransaction() bool { return false } -func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator { - return c.Single().WithTimestampBound(tb).Query(ctx, statement) +func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options ExecOptions) *spanner.RowIterator { + return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions) } type wrappedRowIterator struct { @@ -1397,10 +1483,10 @@ func (ri *wrappedRowIterator) Metadata() *spannerpb.ResultSetMetadata { return ri.RowIterator.Metadata } -func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (rowIterator, time.Time, error) { +func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (rowIterator, time.Time, error) { var result *wrappedRowIterator fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { - it := tx.Query(ctx, statement) + it := tx.QueryWithOptions(ctx, statement, options.QueryOptions) row, err := it.Next() if err == iterator.Done { result = &wrappedRowIterator{ @@ -1418,39 +1504,31 @@ func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement s } return nil } - resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options) + resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options.TransactionOptions) if err != nil { return nil, time.Time{}, err } return result, resp.CommitTs, nil } -func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) { +func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, time.Time, error) { var rowsAffected int64 fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { - count, err := tx.Update(ctx, statement) + count, err := tx.UpdateWithOptions(ctx, statement, options.QueryOptions) rowsAffected = count return err } - resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options) + resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options.TransactionOptions) if err != nil { return 0, time.Time{}, err } return rowsAffected, resp.CommitTs, nil } -func execAsPartitionedDML(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) { - return c.PartitionedUpdateWithOptions(ctx, statement, options) -} - -func (c *conn) createTransactionOptions() spanner.TransactionOptions { - defer func() { c.excludeTxnFromChangeStreams = false }() - return spanner.TransactionOptions{ExcludeTxnFromChangeStreams: c.excludeTxnFromChangeStreams} -} - -func (c *conn) createPartitionedDmlQueryOptions() spanner.QueryOptions { - defer func() { c.excludeTxnFromChangeStreams = false }() - return spanner.QueryOptions{ExcludeTxnFromChangeStreams: c.excludeTxnFromChangeStreams} +func execAsPartitionedDML(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error) { + queryOptions := options.QueryOptions + queryOptions.ExcludeTxnFromChangeStreams = options.TransactionOptions.ExcludeTxnFromChangeStreams + return c.PartitionedUpdateWithOptions(ctx, statement, queryOptions) } /* The following is the same implementation as in google-cloud-go/spanner */ diff --git a/driver_test.go b/driver_test.go index 97d401d5..e8711133 100644 --- a/driver_test.go +++ b/driver_test.go @@ -363,13 +363,13 @@ func TestConn_StartBatchDml(t *testing.T) { func TestConn_NonDdlStatementsInDdlBatch(t *testing.T) { c := &conn{ batch: &batch{tp: ddl}, - execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator { + execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options ExecOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, - execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) { + execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, time.Time, error) { return 0, time.Time{}, nil }, - execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) { + execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error) { return 0, nil }, } @@ -396,13 +396,13 @@ func TestConn_NonDdlStatementsInDdlBatch(t *testing.T) { func TestConn_NonDmlStatementsInDmlBatch(t *testing.T) { c := &conn{ batch: &batch{tp: dml}, - execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator { + execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options ExecOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, - execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) { + execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, time.Time, error) { return 0, time.Time{}, nil }, - execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) { + execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error) { return 0, nil }, } @@ -476,13 +476,13 @@ func TestConn_GetBatchedStatements(t *testing.T) { func TestConn_GetCommitTimestampAfterAutocommitDml(t *testing.T) { want := time.Now() c := &conn{ - execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator { + execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options ExecOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, - execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) { + execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, time.Time, error) { return 0, want, nil }, - execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) { + execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error) { return 0, nil }, } @@ -501,13 +501,13 @@ func TestConn_GetCommitTimestampAfterAutocommitDml(t *testing.T) { func TestConn_GetCommitTimestampAfterAutocommitQuery(t *testing.T) { c := &conn{ - execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator { + execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options ExecOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, - execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) { + execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, time.Time, error) { return 0, time.Time{}, nil }, - execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) { + execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error) { return 0, nil }, } diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 90dc216b..5eac7559 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -2504,10 +2504,18 @@ func TestExcludeTxnFromChangeStreams_AutoCommitBatchDml(t *testing.T) { t.Fatalf("failed to get a connection: %v", err) } - _, _ = conn.ExecContext(ctx, "set exclude_txn_from_change_streams = true") - _, _ = conn.ExecContext(ctx, "start batch dml") - _, _ = conn.ExecContext(ctx, testutil.UpdateBarSetFoo) - _, _ = conn.ExecContext(ctx, "run batch") + if _, err := conn.ExecContext(ctx, "set exclude_txn_from_change_streams = true"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "start batch dml"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "run batch"); err != nil { + t.Fatal(err) + } requests := drainRequestsFromServer(server.TestSpanner) batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{})) if g, w := len(batchRequests), 1; g != w { @@ -2543,9 +2551,15 @@ func TestExcludeTxnFromChangeStreams_PartitionedDml(t *testing.T) { t.Fatalf("failed to get a connection: %v", err) } - conn.ExecContext(ctx, "set exclude_txn_from_change_streams = true") - conn.ExecContext(ctx, "set autocommit_dml_mode = 'partitioned_non_atomic'") - conn.ExecContext(ctx, testutil.UpdateBarSetFoo) + if _, err := conn.ExecContext(ctx, "set exclude_txn_from_change_streams = true"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "set autocommit_dml_mode = 'partitioned_non_atomic'"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + t.Fatal(err) + } requests := drainRequestsFromServer(server.TestSpanner) beginRequests := requestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{})) if g, w := len(beginRequests), 1; g != w { @@ -2608,6 +2622,531 @@ func TestExcludeTxnFromChangeStreams_Transaction(t *testing.T) { } } +func TestTag_Query_AutoCommit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + conn, err := db.Conn(ctx) + defer func() { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + + _, _ = conn.ExecContext(ctx, "set statement_tag = 'tag_1'") + iter, err := conn.QueryContext(ctx, testutil.SelectFooFromBar) + if err != nil { + t.Fatalf("failed to execute query: %v", err) + } + // Just consume the results to ensure that the query is executed. + for iter.Next() { + if iter.Err() != nil { + t.Fatal(iter.Err()) + } + } + iter.Close() + + requests := drainRequestsFromServer(server.TestSpanner) + // The ExecuteSqlRequest and CommitRequest should have a transaction tag. + execRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(execRequests), 1; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + execRequest := execRequests[0].(*sppb.ExecuteSqlRequest) + if g, w := execRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := execRequest.RequestOptions.RequestTag, "tag_1"; g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the tag is reset after the statement. + var statementTag string + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE STATEMENT_TAG").Scan(&statementTag); err != nil { + t.Fatalf("failed to get statement_tag: %v", err) + } + if g, w := statementTag, ""; g != w { + t.Fatalf("statement_tag mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestTag_Update_AutoCommit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + conn, err := db.Conn(ctx) + defer func() { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + + _, _ = conn.ExecContext(ctx, "set transaction_tag = 'my_transaction_tag'") + _, _ = conn.ExecContext(ctx, "set statement_tag = 'tag_1'") + _, _ = conn.ExecContext(ctx, testutil.UpdateBarSetFoo) + + requests := drainRequestsFromServer(server.TestSpanner) + // The ExecuteSqlRequest and CommitRequest should have a transaction tag. + execRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(execRequests), 1; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + execRequest := execRequests[0].(*sppb.ExecuteSqlRequest) + if g, w := execRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := execRequest.RequestOptions.RequestTag, "tag_1"; g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("number of commit request mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequest := commitRequests[0].(*sppb.CommitRequest) + if g, w := commitRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the tag is reset after the statement. + var transactionTag string + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction_tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestTag_AutoCommit_BatchDml(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + conn, err := db.Conn(ctx) + defer func() { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + + _, _ = conn.ExecContext(ctx, "set transaction_tag = 'my_transaction_tag'") + _, _ = conn.ExecContext(ctx, "set statement_tag = 'tag_1'") + _, _ = conn.ExecContext(ctx, "start batch dml") + _, _ = conn.ExecContext(ctx, testutil.UpdateBarSetFoo) + _, _ = conn.ExecContext(ctx, testutil.UpdateBarSetFoo) + _, _ = conn.ExecContext(ctx, "run batch") + + requests := drainRequestsFromServer(server.TestSpanner) + // The ExecuteBatchDmlRequest and CommitRequest should have a transaction tag. + execRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{})) + if g, w := len(execRequests), 1; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + execRequest := execRequests[0].(*sppb.ExecuteBatchDmlRequest) + if g, w := execRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := execRequest.RequestOptions.RequestTag, "tag_1"; g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("number of commit request mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequest := commitRequests[0].(*sppb.CommitRequest) + if g, w := commitRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the tag is reset after the statement. + var transactionTag string + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction_tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestTag_ReadWriteTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + conn, err := db.Conn(ctx) + defer func() { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + + var transactionTag string + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } + _, _ = conn.ExecContext(ctx, "set transaction_tag = 'my_transaction_tag'") + tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + t.Fatal(err) + } + _, _ = tx.ExecContext(ctx, "set statement_tag = 'tag_1'") + rows, _ := tx.QueryContext(ctx, testutil.SelectFooFromBar) + for rows.Next() { + } + rows.Close() + + _, _ = tx.ExecContext(ctx, "set statement_tag = 'tag_2'") + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + + _, _ = tx.ExecContext(ctx, "set statement_tag = 'tag_3'") + _, _ = tx.ExecContext(ctx, "start batch dml") + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + _, _ = tx.ExecContext(ctx, "run batch") + _ = tx.Commit() + + requests := drainRequestsFromServer(server.TestSpanner) + // The ExecuteSqlRequest and CommitRequest should have a transaction tag. + execRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(execRequests), 2; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(execRequests); i++ { + execRequest := execRequests[i].(*sppb.ExecuteSqlRequest) + if g, w := execRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := execRequest.RequestOptions.RequestTag, fmt.Sprintf("tag_%d", (i%2)+1); g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + + batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{})) + if g, w := len(batchRequests), 1; g != w { + t.Fatalf("number of batch request mismatch\n Got: %v\nWant: %v", g, w) + } + batchRequest := batchRequests[0].(*sppb.ExecuteBatchDmlRequest) + if g, w := batchRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := batchRequest.RequestOptions.RequestTag, "tag_3"; g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("number of commit request mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequest := commitRequests[0].(*sppb.CommitRequest) + if g, w := commitRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the tag is reset after the transaction. + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction_tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestTag_ReadWriteTransaction_Retry(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + conn, err := db.Conn(ctx) + defer func() { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + + for _, useArgs := range []bool{false, true} { + var transactionTag string + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } + _, _ = conn.ExecContext(ctx, "set transaction_tag = 'my_transaction_tag'") + tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + t.Fatal(err) + } + + var rows *sql.Rows + if useArgs { + rows, _ = tx.QueryContext(ctx, testutil.SelectFooFromBar, ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "tag_1"}}) + } else { + _, _ = tx.ExecContext(ctx, "set statement_tag='tag_1'") + rows, _ = tx.QueryContext(ctx, testutil.SelectFooFromBar) + } + for rows.Next() { + } + rows.Close() + + if useArgs { + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo, ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "tag_2"}}) + } else { + _, _ = tx.ExecContext(ctx, "set statement_tag='tag_2'") + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + } + + if useArgs { + _, _ = tx.ExecContext(ctx, "start batch dml", ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "tag_3"}}) + } else { + _, _ = tx.ExecContext(ctx, "set statement_tag = 'tag_3'") + _, _ = tx.ExecContext(ctx, "start batch dml") + } + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + _, _ = tx.ExecContext(ctx, "run batch") + + server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}, + }) + _ = tx.Commit() + + requests := drainRequestsFromServer(server.TestSpanner) + // The ExecuteSqlRequest and CommitRequest should have a transaction tag. + execRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(execRequests), 4; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(execRequests); i++ { + execRequest := execRequests[i].(*sppb.ExecuteSqlRequest) + // TODO: Remove when https://github.com/googleapis/google-cloud-go/pull/11443 + // has been merged. + if i < 2 { + if g, w := execRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + if g, w := execRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + if g, w := execRequest.RequestOptions.RequestTag, fmt.Sprintf("tag_%d", (i%2)+1); g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + + batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{})) + if g, w := len(batchRequests), 2; g != w { + t.Fatalf("number of batch request mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(batchRequests); i++ { + batchRequest := batchRequests[i].(*sppb.ExecuteBatchDmlRequest) + // TODO: Remove when https://github.com/googleapis/google-cloud-go/pull/11443 + // has been merged. + if i < 1 { + if g, w := batchRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + if g, w := batchRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + if g, w := batchRequest.RequestOptions.RequestTag, "tag_3"; g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 2; g != w { + t.Fatalf("number of commit request mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(commitRequests); i++ { + commitRequest := commitRequests[i].(*sppb.CommitRequest) + // TODO: Remove when https://github.com/googleapis/google-cloud-go/pull/11443 + // has been merged. + if i < 1 { + if g, w := commitRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + if g, w := commitRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + } + + // Verify that the tag is reset after the transaction. + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction_tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestTag_RunTransaction_Retry(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + conn, err := db.Conn(ctx) + defer func() { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + + for _, useArgs := range []bool{false, true} { + attempts := 0 + err = RunTransactionWithOptions(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + attempts++ + var rows *sql.Rows + if useArgs { + rows, _ = tx.QueryContext(ctx, testutil.SelectFooFromBar, ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "tag_1"}}) + } else { + _, _ = tx.ExecContext(ctx, "set statement_tag='tag_1'") + rows, _ = tx.QueryContext(ctx, testutil.SelectFooFromBar) + } + for rows.Next() { + } + rows.Close() + + if useArgs { + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo, ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "tag_2"}}) + } else { + _, _ = tx.ExecContext(ctx, "set statement_tag='tag_2'") + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + } + + if useArgs { + _, _ = tx.ExecContext(ctx, "start batch dml", ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "tag_3"}}) + } else { + _, _ = tx.ExecContext(ctx, "set statement_tag = 'tag_3'") + _, _ = tx.ExecContext(ctx, "start batch dml") + } + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + _, _ = tx.ExecContext(ctx, "run batch") + if attempts == 1 { + server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}, + }) + } + return nil + }, spanner.TransactionOptions{TransactionTag: "my_transaction_tag"}) + if err != nil { + t.Fatalf("failed to run transaction: %v", err) + } + + requests := drainRequestsFromServer(server.TestSpanner) + // The ExecuteSqlRequest and CommitRequest should have a transaction tag. + execRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(execRequests), 4; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(execRequests); i++ { + execRequest := execRequests[i].(*sppb.ExecuteSqlRequest) + // TODO: Remove when https://github.com/googleapis/google-cloud-go/pull/11443 + // has been merged. + if i < 2 { + if g, w := execRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + if g, w := execRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + if g, w := execRequest.RequestOptions.RequestTag, fmt.Sprintf("tag_%d", (i%2)+1); g != w { + t.Fatalf("useArgs: %v, statement tag mismatch\n Got: %v\nWant: %v", useArgs, g, w) + } + } + + batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{})) + if g, w := len(batchRequests), 2; g != w { + t.Fatalf("number of batch request mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(batchRequests); i++ { + batchRequest := batchRequests[i].(*sppb.ExecuteBatchDmlRequest) + // TODO: Remove when https://github.com/googleapis/google-cloud-go/pull/11443 + // has been merged. + if i < 1 { + if g, w := batchRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + if g, w := batchRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + if g, w := batchRequest.RequestOptions.RequestTag, "tag_3"; g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 2; g != w { + t.Fatalf("number of commit request mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(commitRequests); i++ { + commitRequest := commitRequests[i].(*sppb.CommitRequest) + // TODO: Remove when https://github.com/googleapis/google-cloud-go/pull/11443 + // has been merged. + if i < 1 { + if g, w := commitRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + if g, w := commitRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + } + + // Verify that the transaction tag is reset after the transaction. + var transactionTag string + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction_tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + func TestMaxIdleConnectionsNonZero(t *testing.T) { t.Parallel() diff --git a/examples/run-transaction/main.go b/examples/run-transaction/main.go index c9c2a327..c423e2a9 100644 --- a/examples/run-transaction/main.go +++ b/examples/run-transaction/main.go @@ -20,6 +20,7 @@ import ( "fmt" "sync" + "cloud.google.com/go/spanner" spannerdriver "github.com/googleapis/go-sql-spanner" "github.com/googleapis/go-sql-spanner/examples" ) @@ -59,7 +60,7 @@ func runTransaction(projectId, instanceId, databaseId string) error { // will be aborted and retried by Spanner multiple times. The end result // will still be that all transactions succeed and the name contains all // indexes in an undefined order. - errors[index] = spannerdriver.RunTransaction(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + errors[index] = spannerdriver.RunTransactionWithOptions(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { // Query the singer in the transaction. This will take a lock on the row and guarantee that // the value that we read is still the same when the transaction is committed. If not, Spanner // will abort the transaction, and the transaction will be retried. @@ -82,7 +83,7 @@ func runTransaction(projectId, instanceId, databaseId string) error { return fmt.Errorf("unexpected affected row count: %d", affected) } return nil - }) + }, spanner.TransactionOptions{TransactionTag: "sample_transaction"}) }() } wg.Wait() diff --git a/examples/tags/main.go b/examples/tags/main.go new file mode 100644 index 00000000..132b30f5 --- /dev/null +++ b/examples/tags/main.go @@ -0,0 +1,175 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "database/sql" + "fmt" + + "cloud.google.com/go/spanner" + spannerdriver "github.com/googleapis/go-sql-spanner" + "github.com/googleapis/go-sql-spanner/examples" +) + +var createTableStatement = "CREATE TABLE Singers (SingerId INT64, Name STRING(MAX)) PRIMARY KEY (SingerId)" + +// Example for using transaction tags and statement tags through SQL statements. +// +// Tags can also be set programmatically using spannerdriver.RunTransactionWithOptions +// and the spannerdriver.ExecOptions. +// +// Execute the sample with the command `go run main.go` from this directory. +func tagsWithSqlStatements(projectId, instanceId, databaseId string) error { + fmt.Println("Running sample for setting tags with SQL statements") + + ctx := context.Background() + db, err := sql.Open("spanner", fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectId, instanceId, databaseId)) + if err != nil { + return err + } + defer db.Close() + + // Obtain a connection for the database in order to ensure that we + // set the transaction tag on the same connection as the connection + // that will execute the transaction. + conn, err := db.Conn(ctx) + if err != nil { + return err + } + + // Set a transaction tag on the connection and start a transaction. + // This transaction tag will be applied to the next transaction that + // is executed by this connection. The transaction tag is automatically + // included with all statements of that transaction. + if _, err := conn.ExecContext(ctx, "set transaction_tag = 'my_transaction_tag'"); err != nil { + return err + } + fmt.Println("Executing transaction with transaction tag 'my_transaction_tag'") + tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + + // Set a statement tag and insert a new record using the transaction that we just started. + if _, err := tx.ExecContext(ctx, "set statement_tag = 'insert_singer'"); err != nil { + _ = tx.Rollback() + return err + } + fmt.Println("Executing statement with tag 'insert_singer'") + _, err = tx.ExecContext(ctx, "INSERT INTO Singers (SingerId, Name) VALUES (@id, @name)", 123, "Bruce Allison") + if err != nil { + _ = tx.Rollback() + return err + } + + // Set another statement tag and execute a query. + if _, err := tx.ExecContext(ctx, "set statement_tag = 'select_singer'"); err != nil { + _ = tx.Rollback() + return err + } + fmt.Println("Executing statement with tag 'select_singer'") + rows, err := tx.QueryContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = ?", 123) + if err != nil { + _ = tx.Rollback() + return err + } + var ( + id int64 + name string + ) + for rows.Next() { + if err := rows.Scan(&id, &name); err != nil { + _ = tx.Rollback() + return err + } + fmt.Printf("Found singer: %v %v\n", id, name) + } + if err := rows.Err(); err != nil { + _ = tx.Rollback() + return err + } + _ = rows.Close() + if err := tx.Commit(); err != nil { + return err + } + + fmt.Println("Finished transaction with tag 'my_transaction_tag'") + + return nil +} + +// tagsProgrammatically shows how to set transaction tags and statement tags +// programmatically. +// +// Note: It is not recommended to mix using SQL statements and passing in +// tags or other QueryOptions programmatically. +func tagsProgrammatically(projectId, instanceId, databaseId string) error { + fmt.Println("Running sample for setting tags programmatically") + + ctx := context.Background() + db, err := sql.Open("spanner", fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectId, instanceId, databaseId)) + if err != nil { + return err + } + defer db.Close() + + // Use RunTransactionWithOptions to set a transaction tag programmatically. + fmt.Println("Executing transaction with transaction tag 'my_transaction_tag'") + if err := spannerdriver.RunTransactionWithOptions(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + fmt.Println("Executing statement with tag 'insert_singer'") + // Pass in a value of spanner.QueryOptions to specify the options that should be used for a DML statement. + _, err = tx.ExecContext(ctx, "INSERT INTO Singers (SingerId, Name) VALUES (@id, @name)", spannerdriver.ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "insert_singer"}}, 123, "Bruce Allison") + if err != nil { + return err + } + + fmt.Println("Executing statement with tag 'select_singer'") + // Pass in a value of spanner.QueryOptions to specify the options that should be used for a query. + rows, err := tx.QueryContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = ?", spannerdriver.ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "select_singer"}}, 123) + if err != nil { + return err + } + var ( + id int64 + name string + ) + for rows.Next() { + if err := rows.Scan(&id, &name); err != nil { + _ = tx.Rollback() + return err + } + fmt.Printf("Found singer: %v %v\n", id, name) + } + if err := rows.Err(); err != nil { + return err + } + _ = rows.Close() + + return nil + }, spanner.TransactionOptions{TransactionTag: "my_transaction_tag"}); err != nil { + return err + } + + fmt.Println("Finished transaction with tag 'my_transaction_tag'") + + return nil +} + +func main() { + examples.RunSampleOnEmulator(tagsWithSqlStatements, createTableStatement) + fmt.Println() + examples.RunSampleOnEmulator(tagsProgrammatically, createTableStatement) +} diff --git a/stmt.go b/stmt.go index 7403f6af..8b1555b3 100644 --- a/stmt.go +++ b/stmt.go @@ -59,14 +59,14 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv var it rowIterator if s.conn.tx != nil { - it = s.conn.tx.Query(ctx, ss) + it = s.conn.tx.Query(ctx, ss, s.execOptions.QueryOptions) } else { if s.statementType == statementTypeUnknown { s.statementType = detectStatementType(s.query) } if s.statementType == statementTypeDml { // Use a read/write transaction to execute the statement. - it, _, err = s.conn.execSingleQueryTransactional(ctx, s.conn.client, ss, s.conn.createTransactionOptions()) + it, _, err = s.conn.execSingleQueryTransactional(ctx, s.conn.client, ss, s.execOptions) if err != nil { return nil, err } @@ -74,7 +74,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv // The statement was either detected as being a query, or potentially not recognized at all. // In that case, just default to using a single-use read-only transaction and let Spanner // return an error if the statement is not suited for that type of transaction. - it = &readOnlyRowIterator{s.conn.client.Single().WithTimestampBound(s.conn.readOnlyStaleness).Query(ctx, ss)} + it = &readOnlyRowIterator{s.conn.client.Single().WithTimestampBound(s.conn.readOnlyStaleness).QueryWithOptions(ctx, ss, s.execOptions.QueryOptions)} } } return &rows{it: it, decodeOption: s.execOptions.DecodeOption}, nil diff --git a/transaction.go b/transaction.go index 465024a5..bc11e013 100644 --- a/transaction.go +++ b/transaction.go @@ -33,10 +33,10 @@ type contextTransaction interface { Commit() error Rollback() error resetForRetry(ctx context.Context) error - Query(ctx context.Context, stmt spanner.Statement) rowIterator - ExecContext(ctx context.Context, stmt spanner.Statement) (int64, error) + Query(ctx context.Context, stmt spanner.Statement, options spanner.QueryOptions) rowIterator + ExecContext(ctx context.Context, stmt spanner.Statement, options spanner.QueryOptions) (int64, error) - StartBatchDML() (driver.Result, error) + StartBatchDML(options spanner.QueryOptions) (driver.Result, error) RunBatch(ctx context.Context) (driver.Result, error) AbortBatch() (driver.Result, error) @@ -95,15 +95,15 @@ func (tx *readOnlyTransaction) resetForRetry(ctx context.Context) error { return nil } -func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement) rowIterator { - return &readOnlyRowIterator{tx.roTx.Query(ctx, stmt)} +func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement, options spanner.QueryOptions) rowIterator { + return &readOnlyRowIterator{tx.roTx.QueryWithOptions(ctx, stmt, options)} } -func (tx *readOnlyTransaction) ExecContext(_ context.Context, stmt spanner.Statement) (int64, error) { +func (tx *readOnlyTransaction) ExecContext(_ context.Context, _ spanner.Statement, _ spanner.QueryOptions) (int64, error) { return 0, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "read-only transactions cannot write")) } -func (tx *readOnlyTransaction) StartBatchDML() (driver.Result, error) { +func (tx *readOnlyTransaction) StartBatchDML(_ spanner.QueryOptions) (driver.Result, error) { return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "read-only transactions cannot write")) } @@ -175,7 +175,8 @@ type retriableStatement interface { // retriableUpdate implements retriableStatement for update statements. type retriableUpdate struct { // stmt is the statement that was executed on Spanner. - stmt spanner.Statement + stmt spanner.Statement + options spanner.QueryOptions // c is the record count that was returned by Spanner. c int64 // err is the error that was returned by Spanner. @@ -186,7 +187,7 @@ type retriableUpdate struct { // of the statement during the retry is equal to the result during the initial // attempt. func (ru *retriableUpdate) retry(ctx context.Context, tx *spanner.ReadWriteStmtBasedTransaction) error { - c, err := tx.Update(ctx, ru.stmt) + c, err := tx.UpdateWithOptions(ctx, ru.stmt, ru.options) if err != nil && spanner.ErrCode(err) == codes.Aborted { return err } @@ -203,6 +204,7 @@ func (ru *retriableUpdate) retry(ctx context.Context, tx *spanner.ReadWriteStmtB type retriableBatchUpdate struct { // statements are the statement that were executed on Spanner. statements []spanner.Statement + options spanner.QueryOptions // c is the record counts that were returned by Spanner. c []int64 // err is the error that was returned by Spanner. @@ -213,7 +215,7 @@ type retriableBatchUpdate struct { // of the statement during the retry is equal to the result during the initial // attempt. func (ru *retriableBatchUpdate) retry(ctx context.Context, tx *spanner.ReadWriteStmtBasedTransaction) error { - c, err := tx.BatchUpdate(ctx, ru.statements) + c, err := tx.BatchUpdateWithOptions(ctx, ru.statements, ru.options) if err != nil && spanner.ErrCode(err) == codes.Aborted { return err } @@ -313,21 +315,22 @@ func (tx *readWriteTransaction) resetForRetry(ctx context.Context) error { // Query executes a query using the read/write transaction and returns a // rowIterator that will automatically retry the read/write transaction if the // transaction is aborted during the query or while iterating the returned rows. -func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statement) rowIterator { +func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statement, options spanner.QueryOptions) rowIterator { // If internal retries have been disabled, we don't need to keep track of a // running checksum for all results that we have seen. if !tx.retryAborts { - return &readOnlyRowIterator{tx.rwTx.Query(ctx, stmt)} + return &readOnlyRowIterator{tx.rwTx.QueryWithOptions(ctx, stmt, options)} } // If retries are enabled, we need to use a row iterator that will keep // track of a running checksum of all the results that we see. buffer := &bytes.Buffer{} it := &checksumRowIterator{ - RowIterator: tx.rwTx.Query(ctx, stmt), + RowIterator: tx.rwTx.QueryWithOptions(ctx, stmt, options), ctx: ctx, tx: tx, stmt: stmt, + options: options, buffer: buffer, enc: gob.NewEncoder(buffer), } @@ -335,33 +338,34 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen return it } -func (tx *readWriteTransaction) ExecContext(ctx context.Context, stmt spanner.Statement) (res int64, err error) { +func (tx *readWriteTransaction) ExecContext(ctx context.Context, stmt spanner.Statement, options spanner.QueryOptions) (res int64, err error) { if tx.batch != nil { tx.batch.statements = append(tx.batch.statements, stmt) return 0, nil } if !tx.retryAborts { - return tx.rwTx.Update(ctx, stmt) + return tx.rwTx.UpdateWithOptions(ctx, stmt, options) } err = tx.runWithRetry(ctx, func(ctx context.Context) error { - res, err = tx.rwTx.Update(ctx, stmt) + res, err = tx.rwTx.UpdateWithOptions(ctx, stmt, options) return err }) tx.statements = append(tx.statements, &retriableUpdate{ - stmt: stmt, - c: res, - err: err, + stmt: stmt, + options: options, + c: res, + err: err, }) return res, err } -func (tx *readWriteTransaction) StartBatchDML() (driver.Result, error) { +func (tx *readWriteTransaction) StartBatchDML(options spanner.QueryOptions) (driver.Result, error) { if tx.batch != nil { return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "This transaction already has an active batch.")) } - tx.batch = &batch{tp: dml} + tx.batch = &batch{tp: dml, options: ExecOptions{QueryOptions: options}} return driver.ResultNoRows, nil } @@ -386,21 +390,23 @@ func (tx *readWriteTransaction) AbortBatch() (driver.Result, error) { func (tx *readWriteTransaction) runDmlBatch(ctx context.Context) (driver.Result, error) { statements := tx.batch.statements + options := tx.batch.options tx.batch = nil if !tx.retryAborts { - affected, err := tx.rwTx.BatchUpdate(ctx, statements) + affected, err := tx.rwTx.BatchUpdateWithOptions(ctx, statements, options.QueryOptions) return &result{rowsAffected: sum(affected)}, err } var affected []int64 var err error err = tx.runWithRetry(ctx, func(ctx context.Context) error { - affected, err = tx.rwTx.BatchUpdate(ctx, statements) + affected, err = tx.rwTx.BatchUpdateWithOptions(ctx, statements, options.QueryOptions) return err }) tx.statements = append(tx.statements, &retriableBatchUpdate{ statements: statements, + options: options.QueryOptions, c: affected, err: err, })