From 6d66c1a245195b141ebb536e7aef4bc504051a6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Tue, 14 Jan 2025 10:49:50 +0100 Subject: [PATCH] feat: support transaction and statement tags Adds support for the use of transaction and statement (request) tags. Tags can be used to better identify where transactions and statements are coming from, and to debug performance issues. This change also introduces the generic option to pass in QueryOptions and TransactionOptions programmatically. --- README.md | 39 ++- checksum_row_iterator.go | 9 +- client_side_statement.go | 45 +++ client_side_statement_test.go | 72 +++++ client_side_statements_json.go | 48 +++ driver.go | 161 ++++++++-- driver_test.go | 16 +- driver_with_mockserver_test.go | 525 +++++++++++++++++++++++++++++++ examples/run-transaction/main.go | 5 +- examples/tags/main.go | 175 +++++++++++ stmt.go | 23 +- transaction.go | 52 +-- 12 files changed, 1103 insertions(+), 67 deletions(-) create mode 100644 examples/tags/main.go diff --git a/README.md b/README.md index f5f07771..4a714c99 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [Google Cloud Spanner](https://cloud.google.com/spanner) driver for Go's [database/sql](https://golang.org/pkg/database/sql/) package. -``` go +```go import _ "github.com/googleapis/go-sql-spanner" db, err := sql.Open("spanner", "projects/PROJECT/instances/INSTANCE/databases/DATABASE") @@ -65,6 +65,27 @@ the same named query parameter is used in multiple places in the statement. db.ExecContext(ctx, "DELETE FROM tweets WHERE id = @id", 14544498215374) ``` +### Query Options +Query options can be passed in as arguments to a query. Any value of type +`spanner.QueryOptions` will be passed through to the Spanner client as the +query options to use for the query or DML statement, and will be skipped +when parsing the actual query parameters. + +```go +tx.ExecContext(ctx, "INSERT INTO Singers (SingerId, Name) VALUES (@id, @name)", + spanner.QueryOptions{RequestTag: "insert_singer"}, 123, "Bruce Allison") +tx.QueryContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = ?", + spanner.QueryOptions{RequestTag: "select_singer"}, 123) +``` + +Statement tags (request tags) can also be set using the custom SQL statement +`set statement_tag='my_tag'`: + +```go +tx.ExecContext(ctx, "set statement_tag = 'select_singer'") +tx.QueryContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = ?", 123) +``` + ## Transactions - Read-write transactions always uses the strongest isolation level and ignore the user-specified level. @@ -72,9 +93,15 @@ db.ExecContext(ctx, "DELETE FROM tweets WHERE id = @id", 14544498215374) either Commit or Rollback. Calling either of these methods will end the current read-only transaction and return the session that is used to the session pool. -``` go +```go tx, err := db.BeginTx(ctx, &sql.TxOptions{}) // Read-write transaction. +// Read-write transaction with a transaction tag. +conn, _ := db.Conn(ctx) +_, _ := conn.ExecContext(ctx, "SET TRANSACTION_TAG='my_transaction_tag'") +tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) + + tx, err := db.BeginTx(ctx, &sql.TxOptions{ ReadOnly: true, // Read-only transaction using strong reads. }) @@ -91,8 +118,8 @@ tx, err := conn.BeginTx(ctx, &sql.TxOptions{ Spanner can abort a read/write transaction if concurrent modifications are detected that would violate the transaction consistency. When this happens, the driver will return the `ErrAbortedDueToConcurrentModification` error. You can use the -`RunTransaction` function to let the driver automatically retry transactions that -are aborted by Spanner. +`RunTransaction` and `RunTransactionWithOptions` functions to let the driver +automatically retry transactions that are aborted by Spanner. ```go package sample @@ -106,14 +133,14 @@ import ( spannerdriver "github.com/googleapis/go-sql-spanner" ) -spannerdriver.RunTransaction(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { +spannerdriver.RunTransactionWithOptions(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { row := tx.QueryRowContext(ctx, "select Name from Singers where SingerId=@id", 123) var name string if err := row.Scan(&name); err != nil { return err } return nil -}) +}, spanner.TransactionOptions{TransactionTag: "my_transaction_tag"}) ``` See also the [transaction runner sample](./examples/run-transaction/main.go). diff --git a/checksum_row_iterator.go b/checksum_row_iterator.go index 8690df0d..fe0aaa5c 100644 --- a/checksum_row_iterator.go +++ b/checksum_row_iterator.go @@ -49,9 +49,10 @@ type checksumRowIterator struct { *spanner.RowIterator metadata *sppb.ResultSetMetadata - ctx context.Context - tx *readWriteTransaction - stmt spanner.Statement + ctx context.Context + tx *readWriteTransaction + stmt spanner.Statement + options spanner.QueryOptions // nc (nextCount) indicates the number of times that next has been called // on the iterator. Next() will be called the same number of times during // a retry. @@ -160,7 +161,7 @@ func createMetadataChecksum(enc *gob.Encoder, buffer *bytes.Buffer, metadata *sp func (it *checksumRowIterator) retry(ctx context.Context, tx *spanner.ReadWriteStmtBasedTransaction) error { buffer := &bytes.Buffer{} enc := gob.NewEncoder(buffer) - retryIt := tx.Query(ctx, it.stmt) + retryIt := tx.QueryWithOptions(ctx, it.stmt, it.options) // If the original iterator had been stopped, we should also always stop the // new iterator. if it.stopped { diff --git a/client_side_statement.go b/client_side_statement.go index c8134858..693d0087 100644 --- a/client_side_statement.go +++ b/client_side_statement.go @@ -93,6 +93,22 @@ func (s *statementExecutor) ShowExcludeTxnFromChangeStreams(_ context.Context, c return &rows{it: it}, nil } +func (s *statementExecutor) ShowTransactionTag(_ context.Context, c *conn, _ string, _ []driver.NamedValue) (driver.Rows, error) { + it, err := createStringIterator("TransactionTag", c.TransactionTag()) + if err != nil { + return nil, err + } + return &rows{it: it}, nil +} + +func (s *statementExecutor) ShowStatementTag(_ context.Context, c *conn, _ string, _ []driver.NamedValue) (driver.Rows, error) { + it, err := createStringIterator("StatementTag", c.StatementTag()) + if err != nil { + return nil, err + } + return &rows{it: it}, nil +} + func (s *statementExecutor) StartBatchDdl(_ context.Context, c *conn, _ string, _ []driver.NamedValue) (driver.Result, error) { return c.startBatchDDL() } @@ -147,6 +163,35 @@ func (s *statementExecutor) SetExcludeTxnFromChangeStreams(_ context.Context, c return c.setExcludeTxnFromChangeStreams(exclude) } +func (s *statementExecutor) SetTransactionTag(_ context.Context, c *conn, params string, _ []driver.NamedValue) (driver.Result, error) { + tag, err := parseTag(params) + if err != nil { + return nil, err + } + return c.setTransactionTag(tag) +} + +func (s *statementExecutor) SetStatementTag(_ context.Context, c *conn, params string, _ []driver.NamedValue) (driver.Result, error) { + tag, err := parseTag(params) + if err != nil { + return nil, err + } + return c.setStatementTag(tag) +} + +func parseTag(params string) (string, error) { + if params == "" { + return "", spanner.ToSpannerError(status.Error(codes.InvalidArgument, "no value given for tag")) + } + tag := strings.TrimSpace(params) + if !(strings.HasPrefix(tag, "'") && strings.HasSuffix(tag, "'")) { + return "", spanner.ToSpannerError(status.Error(codes.InvalidArgument, "missing single quotes around tag")) + } + tag = strings.TrimLeft(tag, "'") + tag = strings.TrimRight(tag, "'") + return tag, nil +} + var strongRegexp = regexp.MustCompile("(?i)'STRONG'") var exactStalenessRegexp = regexp.MustCompile(`(?i)'(?PEXACT_STALENESS)[\t ]+(?P(\d{1,19})(s|ms|us|ns))'`) var maxStalenessRegexp = regexp.MustCompile(`(?i)'(?PMAX_STALENESS)[\t ]+(?P(\d{1,19})(s|ms|us|ns))'`) diff --git a/client_side_statement_test.go b/client_side_statement_test.go index 29085d75..4c0e5d5a 100644 --- a/client_side_statement_test.go +++ b/client_side_statement_test.go @@ -374,3 +374,75 @@ func TestStatementExecutor_ExcludeTxnFromChangeStreams(t *testing.T) { } } } + +func TestStatementExecutor_SetTransactionTag(t *testing.T) { + ctx := context.Background() + for i, test := range []struct { + wantValue string + setValue string + wantSetErr bool + }{ + {"test-tag", "'test-tag'", false}, + {"other-tag", " 'other-tag'\t\n", false}, + {" tag with spaces ", "' tag with spaces '", false}, + {"", "tag-without-quotes", true}, + {"", "tag-with-missing-opening-quote'", true}, + {"", "'tag-with-missing-closing-quote", true}, + } { + c := &conn{retryAborts: true} + s := &statementExecutor{} + + it, err := s.ShowTransactionTag(ctx, c, "", nil) + if err != nil { + t.Fatalf("%d: could not get current transaction tag value from connection: %v", i, err) + } + cols := it.Columns() + wantCols := []string{"TransactionTag"} + if !cmp.Equal(cols, wantCols) { + t.Fatalf("%d: column names mismatch\nGot: %v\nWant: %v", i, cols, wantCols) + } + values := make([]driver.Value, len(cols)) + if err := it.Next(values); err != nil { + t.Fatalf("%d: failed to get first row: %v", i, err) + } + wantValues := []driver.Value{""} + if !cmp.Equal(values, wantValues) { + t.Fatalf("%d: default transaction tag mismatch\nGot: %v\nWant: %v", i, values, wantValues) + } + if err := it.Next(values); err != io.EOF { + t.Fatalf("%d: error mismatch\nGot: %v\nWant: %v", i, err, io.EOF) + } + + // Set a transaction tag. + res, err := s.SetTransactionTag(ctx, c, test.setValue, nil) + if test.wantSetErr { + if err == nil { + t.Fatalf("%d: missing expected error for value %q", i, test.setValue) + } + } else { + if err != nil { + t.Fatalf("%d: could not set new value %q for exclude: %v", i, test.setValue, err) + } + if res != driver.ResultNoRows { + t.Fatalf("%d: result mismatch\nGot: %v\nWant: %v", i, res, driver.ResultNoRows) + } + } + + // Get the tag that was set + it, err = s.ShowTransactionTag(ctx, c, "", nil) + if err != nil { + t.Fatalf("%d: could not get current transaction tag value from connection: %v", i, err) + } + if err := it.Next(values); err != nil { + t.Fatalf("%d: failed to get first row: %v", i, err) + } + wantValues = []driver.Value{test.wantValue} + if !cmp.Equal(values, wantValues) { + t.Fatalf("%d: transaction tag mismatch\nGot: %v\nWant: %v", i, values, wantValues) + } + if err := it.Next(values); err != io.EOF { + t.Fatalf("%d: error mismatch\nGot: %v\nWant: %v", i, err, io.EOF) + } + + } +} diff --git a/client_side_statements_json.go b/client_side_statements_json.go index ad848391..7c047888 100644 --- a/client_side_statements_json.go +++ b/client_side_statements_json.go @@ -60,6 +60,24 @@ var jsonFile = `{ "method": "statementShowExcludeTxnFromChangeStreams", "exampleStatements": ["show variable exclude_txn_from_change_streams"] }, + { + "name": "SHOW VARIABLE TRANSACTION_TAG", + "executorName": "ClientSideStatementNoParamExecutor", + "resultType": "RESULT_SET", + "statementType": "SHOW_TRANSACTION_TAG", + "regex": "(?is)\\A\\s*show\\s+variable\\s+transaction_tag\\s*\\z", + "method": "statementShowTransactionTag", + "exampleStatements": ["show variable transaction_tag"] + }, + { + "name": "SHOW VARIABLE STATEMENT_TAG", + "executorName": "ClientSideStatementNoParamExecutor", + "resultType": "RESULT_SET", + "statementType": "SHOW_STATEMENT_TAG", + "regex": "(?is)\\A\\s*show\\s+variable\\s+statement_tag\\s*\\z", + "method": "statementShowStatementTag", + "exampleStatements": ["show variable statement_tag"] + }, { "name": "START BATCH DDL", "executorName": "ClientSideStatementNoParamExecutor", @@ -165,6 +183,36 @@ var jsonFile = `{ "allowedValues": "(TRUE|FALSE)", "converterName": "ClientSideStatementValueConverters$BooleanConverter" } + }, + { + "name": "SET TRANSACTION_TAG = ''", + "executorName": "ClientSideStatementSetExecutor", + "resultType": "NO_RESULT", + "statementType": "SET_TRANSACTION_TAG", + "regex": "(?is)\\A\\s*set\\s+transaction_tag\\s*(?:=)\\s*(.*)\\z", + "method": "statementSetTransactionTag", + "exampleStatements": ["set transaction_tag='tag1'", "set transaction_tag='tag2'", "set transaction_tag=''", "set transaction_tag='test_tag'"], + "setStatement": { + "propertyName": "TRANSACTION_TAG", + "separator": "=", + "allowedValues": "'(.*)'", + "converterName": "ClientSideStatementValueConverters$StringValueConverter" + } + }, + { + "name": "SET STATEMENT_TAG = ''", + "executorName": "ClientSideStatementSetExecutor", + "resultType": "NO_RESULT", + "statementType": "SET_STATEMENT_TAG", + "regex": "(?is)\\A\\s*set\\s+statement_tag\\s*(?:=)\\s*(.*)\\z", + "method": "statementSetStatementTag", + "exampleStatements": ["set statement_tag='tag1'", "set statement_tag='tag2'", "set statement_tag=''", "set statement_tag='test_tag'"], + "setStatement": { + "propertyName": "STATEMENT_TAG", + "separator": "=", + "allowedValues": "'(.*)'", + "converterName": "ClientSideStatementValueConverters$StringValueConverter" + } } ] } diff --git a/driver.go b/driver.go index c851b4af..45d61f93 100644 --- a/driver.go +++ b/driver.go @@ -427,6 +427,32 @@ func (c *connector) closeClients() (err error) { // // This function will never return ErrAbortedDueToConcurrentModification. func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error) error { + return runTransactionWithOptions(ctx, db, opts, f, nil) +} + +// RunTransactionWithOptions runs the given function in a transaction on the given database. +// If the connection is a connection to a Spanner database, the transaction will +// automatically be retried if the transaction is aborted by Spanner. Any other +// errors will be propagated to the caller and the transaction will be rolled +// back. The transaction will be committed if the supplied function did not +// return an error. +// +// If the connection is to a non-Spanner database, no retries will be attempted, +// and any error that occurs during the transaction will be propagated to the +// caller. +// +// The application should *NOT* call tx.Commit() or tx.Rollback(). This is done +// automatically by this function, depending on whether the transaction function +// returned an error or not. +// +// The given spanner.TransactionOptions will be used for the transaction. +// +// This function will never return ErrAbortedDueToConcurrentModification. +func RunTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error, spannerOptions spanner.TransactionOptions) error { + return runTransactionWithOptions(ctx, db, opts, f, &spannerOptions) +} + +func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error, spannerOptions *spanner.TransactionOptions) error { // Get a connection from the pool that we can use to run a transaction. // Getting a connection here already makes sure that we can reserve this // connection exclusively for the duration of this method. That again @@ -452,6 +478,7 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func // It is not a Spanner connection, so just ignore and continue without any special handling. return nil } + spannerConn.withTransactionOptions(spannerOptions) origRetryAborts = spannerConn.RetryAbortsInternally() return spannerConn.SetRetryAbortsInternally(false) }); err != nil { @@ -588,6 +615,15 @@ type SpannerConn interface { // mode and for read-only transaction. SetReadOnlyStaleness(staleness spanner.TimestampBound) error + // TransactionTag returns the transaction tag that will be applied to the next + // read/write transaction on this connection. The transaction tag that is set + // on the connection is cleared when a read/write transaction is started. + TransactionTag() string + // SetTransactionTag sets the transaction tag that should be applied to the + // next read/write transaction on this connection. The tag is cleared when a + // read/write transaction is started. + SetTransactionTag(transactionTag string) error + // ExcludeTxnFromChangeStreams returns true if the next transaction should be excluded from change streams with the // DDL option `allow_txn_exclusion=true`. ExcludeTxnFromChangeStreams() bool @@ -624,8 +660,14 @@ type SpannerConn interface { // has not been aborted is not supported and will cause an error to be // returned. resetTransactionForRetry(ctx context.Context, errDuringCommit bool) error + + // setTransactionOptions sets the TransactionOptions that should be used + // for this transaction. + withTransactionOptions(options *spanner.TransactionOptions) } +var _ SpannerConn = &conn{} + type conn struct { connector *connector closed bool @@ -638,8 +680,8 @@ type conn struct { database string retryAborts bool - execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound) *spanner.RowIterator - execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (int64, time.Time, error) + execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound, options spanner.QueryOptions) *spanner.RowIterator + execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions, queryOptions spanner.QueryOptions) (int64, time.Time, error) execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) // batch is the currently active DDL or DML batch on this connection. @@ -652,9 +694,20 @@ type conn struct { autocommitDMLMode AutocommitDMLMode // readOnlyStaleness is used for queries in autocommit mode and for read-only transactions. readOnlyStaleness spanner.TimestampBound - // excludeTxnFromChangeStreams is used to exlude the next transaction from change streams with the DDL option + + // excludeTxnFromChangeStreams is used to exclude the next transaction from change streams with the DDL option // `allow_txn_exclusion=true` excludeTxnFromChangeStreams bool + + // transactionTag is applied to the next read/write transaction on this connection. + transactionTag string + // options is applied to the next statement that is executed on this connection. + options spanner.QueryOptions + + // txOptions is applied to the next read/write transaction on this connection. + // This value overrides all other transaction options that have been set on + // this connection (e.g. a transactionTag). + txOptions *spanner.TransactionOptions } type batchType int @@ -667,6 +720,7 @@ const ( type batch struct { tp batchType statements []spanner.Statement + options spanner.QueryOptions } // AutocommitDMLMode indicates whether a single DML statement should be executed @@ -762,6 +816,37 @@ func (c *conn) setExcludeTxnFromChangeStreams(excludeTxnFromChangeStreams bool) return driver.ResultNoRows, nil } +func (c *conn) TransactionTag() string { + return c.transactionTag +} + +func (c *conn) SetTransactionTag(transactionTag string) error { + _, err := c.setTransactionTag(transactionTag) + return err +} + +func (c *conn) setTransactionTag(transactionTag string) (driver.Result, error) { + if c.inTransaction() { + return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "cannot set transaction tag while a transaction is active")) + } + c.transactionTag = transactionTag + return driver.ResultNoRows, nil +} + +func (c *conn) StatementTag() string { + return c.options.RequestTag +} + +func (c *conn) SetStatementTag(statementTag string) error { + _, err := c.setStatementTag(statementTag) + return err +} + +func (c *conn) setStatementTag(statementTag string) (driver.Result, error) { + c.options.RequestTag = statementTag + return driver.ResultNoRows, nil +} + func (c *conn) StartBatchDDL() error { _, err := c.startBatchDDL() return err @@ -814,7 +899,7 @@ func (c *conn) startBatchDDL() (driver.Result, error) { func (c *conn) startBatchDML() (driver.Result, error) { if c.inTransaction() { - return c.tx.StartBatchDML() + return c.tx.StartBatchDML(c.createQueryOptions()) } if c.batch != nil { @@ -823,7 +908,7 @@ func (c *conn) startBatchDML() (driver.Result, error) { if c.inReadOnlyTransaction() { return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "This connection has an active read-only transaction. Read-only transactions cannot execute DML batches.")) } - c.batch = &batch{tp: dml} + c.batch = &batch{tp: dml, options: c.createQueryOptions()} return driver.ResultNoRows, nil } @@ -853,8 +938,9 @@ func (c *conn) runDDLBatch(ctx context.Context) (driver.Result, error) { func (c *conn) runDMLBatch(ctx context.Context) (driver.Result, error) { statements := c.batch.statements + options := c.batch.options c.batch = nil - return c.execBatchDML(ctx, statements) + return c.execBatchDML(ctx, statements, options) } func (c *conn) abortBatch() (driver.Result, error) { @@ -894,7 +980,7 @@ func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (dr return driver.ResultNoRows, nil } -func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement) (driver.Result, error) { +func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement, options spanner.QueryOptions) (driver.Result, error) { if len(statements) == 0 { return &result{}, nil } @@ -906,10 +992,10 @@ func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement) if !ok { return nil, status.Errorf(codes.FailedPrecondition, "connection is in a transaction that is not a read/write transaction") } - affected, err = tx.rwTx.BatchUpdate(ctx, statements) + affected, err = tx.rwTx.BatchUpdateWithOptions(ctx, statements, options) } else { _, err = c.client.ReadWriteTransactionWithOptions(ctx, func(ctx context.Context, transaction *spanner.ReadWriteTransaction) error { - affected, err = transaction.BatchUpdate(ctx, statements) + affected, err = transaction.BatchUpdateWithOptions(ctx, statements, options) return err }, c.createTransactionOptions()) } @@ -1077,6 +1163,11 @@ func (c *conn) CheckNamedValue(value *driver.NamedValue) error { if value == nil { return nil } + options, ok := value.Value.(spanner.QueryOptions) + if ok { + c.options = options + return driver.ErrRemoveArgument + } if checkIsValidType(value.Value) { return nil } @@ -1125,14 +1216,15 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam c.commitTs = nil stmt, err := prepareSpannerStmt(query, args) + options := c.createQueryOptions() if err != nil { return nil, err } var iter rowIterator if c.tx == nil { - iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.readOnlyStaleness)} + iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.readOnlyStaleness, options)} } else { - iter = c.tx.Query(ctx, stmt) + iter = c.tx.Query(ctx, stmt, options) } return &rows{it: iter}, nil } @@ -1176,7 +1268,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name c.batch.statements = append(c.batch.statements, ss) } else { if c.autocommitDMLMode == Transactional { - rowsAffected, commitTs, err = c.execSingleDMLTransactional(ctx, c.client, ss, c.createTransactionOptions()) + rowsAffected, commitTs, err = c.execSingleDMLTransactional(ctx, c.client, ss, c.createTransactionOptions(), c.createQueryOptions()) if err == nil { c.commitTs = &commitTs } @@ -1187,7 +1279,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name } } } else { - rowsAffected, err = c.tx.ExecContext(ctx, ss) + rowsAffected, err = c.tx.ExecContext(ctx, ss, c.createQueryOptions()) } if err != nil { return nil, err @@ -1216,6 +1308,10 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo return c.tx.resetForRetry(ctx) } +func (c *conn) withTransactionOptions(options *spanner.TransactionOptions) { + c.txOptions = options +} + type spannerIsolationLevel sql.IsolationLevel const ( @@ -1308,18 +1404,18 @@ func (c *conn) inReadWriteTransaction() bool { return false } -func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator { - return c.Single().WithTimestampBound(tb).Query(ctx, statement) +func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options spanner.QueryOptions) *spanner.RowIterator { + return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options) } -func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) { +func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions, queryOptions spanner.QueryOptions) (int64, time.Time, error) { var rowsAffected int64 fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { - count, err := tx.Update(ctx, statement) + count, err := tx.UpdateWithOptions(ctx, statement, queryOptions) rowsAffected = count return err } - resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options) + resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, transactionOptions) if err != nil { return 0, time.Time{}, err } @@ -1330,14 +1426,35 @@ func execAsPartitionedDML(ctx context.Context, c *spanner.Client, statement span return c.PartitionedUpdateWithOptions(ctx, statement, options) } +func (c *conn) createQueryOptions() spanner.QueryOptions { + defer func() { + c.options = spanner.QueryOptions{} + }() + return c.options +} + func (c *conn) createTransactionOptions() spanner.TransactionOptions { - defer func() { c.excludeTxnFromChangeStreams = false }() - return spanner.TransactionOptions{ExcludeTxnFromChangeStreams: c.excludeTxnFromChangeStreams} + defer func() { + c.excludeTxnFromChangeStreams = false + c.transactionTag = "" + c.txOptions = nil + }() + if c.txOptions != nil { + return *c.txOptions + } + return spanner.TransactionOptions{ + ExcludeTxnFromChangeStreams: c.excludeTxnFromChangeStreams, + TransactionTag: c.transactionTag, + } } func (c *conn) createPartitionedDmlQueryOptions() spanner.QueryOptions { - defer func() { c.excludeTxnFromChangeStreams = false }() - return spanner.QueryOptions{ExcludeTxnFromChangeStreams: c.excludeTxnFromChangeStreams} + defer func() { + c.excludeTxnFromChangeStreams = false + c.options = spanner.QueryOptions{} + }() + c.options.ExcludeTxnFromChangeStreams = c.excludeTxnFromChangeStreams + return c.options } /* The following is the same implementation as in google-cloud-go/spanner */ diff --git a/driver_test.go b/driver_test.go index 97d401d5..f6e6f6b8 100644 --- a/driver_test.go +++ b/driver_test.go @@ -363,10 +363,10 @@ func TestConn_StartBatchDml(t *testing.T) { func TestConn_NonDdlStatementsInDdlBatch(t *testing.T) { c := &conn{ batch: &batch{tp: ddl}, - execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator { + execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options spanner.QueryOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, - execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) { + execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions, queryOptions spanner.QueryOptions) (int64, time.Time, error) { return 0, time.Time{}, nil }, execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) { @@ -396,10 +396,10 @@ func TestConn_NonDdlStatementsInDdlBatch(t *testing.T) { func TestConn_NonDmlStatementsInDmlBatch(t *testing.T) { c := &conn{ batch: &batch{tp: dml}, - execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator { + execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options spanner.QueryOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, - execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) { + execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions, queryOptions spanner.QueryOptions) (int64, time.Time, error) { return 0, time.Time{}, nil }, execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) { @@ -476,10 +476,10 @@ func TestConn_GetBatchedStatements(t *testing.T) { func TestConn_GetCommitTimestampAfterAutocommitDml(t *testing.T) { want := time.Now() c := &conn{ - execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator { + execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options spanner.QueryOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, - execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) { + execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions, queryOptions spanner.QueryOptions) (int64, time.Time, error) { return 0, want, nil }, execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) { @@ -501,10 +501,10 @@ func TestConn_GetCommitTimestampAfterAutocommitDml(t *testing.T) { func TestConn_GetCommitTimestampAfterAutocommitQuery(t *testing.T) { c := &conn{ - execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound) *spanner.RowIterator { + execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options spanner.QueryOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, - execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.TransactionOptions) (int64, time.Time, error) { + execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions, queryOptions spanner.QueryOptions) (int64, time.Time, error) { return 0, time.Time{}, nil }, execSingleDMLPartitioned: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options spanner.QueryOptions) (int64, error) { diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 1817c426..6b2d8fc3 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -2452,6 +2452,531 @@ func TestExcludeTxnFromChangeStreams_Transaction(t *testing.T) { } } +func TestTag_Query_AutoCommit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + conn, err := db.Conn(ctx) + defer func() { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + + _, _ = conn.ExecContext(ctx, "set statement_tag = 'tag_1'") + iter, err := conn.QueryContext(ctx, testutil.SelectFooFromBar) + if err != nil { + t.Fatalf("failed to execute query: %v", err) + } + // Just consume the results to ensure that the query is executed. + for iter.Next() { + if iter.Err() != nil { + t.Fatal(iter.Err()) + } + } + iter.Close() + + requests := drainRequestsFromServer(server.TestSpanner) + // The ExecuteSqlRequest and CommitRequest should have a transaction tag. + execRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(execRequests), 1; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + execRequest := execRequests[0].(*sppb.ExecuteSqlRequest) + if g, w := execRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := execRequest.RequestOptions.RequestTag, "tag_1"; g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the tag is reset after the statement. + var statementTag string + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE STATEMENT_TAG").Scan(&statementTag); err != nil { + t.Fatalf("failed to get statement_tag: %v", err) + } + if g, w := statementTag, ""; g != w { + t.Fatalf("statement_tag mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestTag_Update_AutoCommit(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + conn, err := db.Conn(ctx) + defer func() { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + + _, _ = conn.ExecContext(ctx, "set transaction_tag = 'my_transaction_tag'") + _, _ = conn.ExecContext(ctx, "set statement_tag = 'tag_1'") + _, _ = conn.ExecContext(ctx, testutil.UpdateBarSetFoo) + + requests := drainRequestsFromServer(server.TestSpanner) + // The ExecuteSqlRequest and CommitRequest should have a transaction tag. + execRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(execRequests), 1; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + execRequest := execRequests[0].(*sppb.ExecuteSqlRequest) + if g, w := execRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := execRequest.RequestOptions.RequestTag, "tag_1"; g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("number of commit request mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequest := commitRequests[0].(*sppb.CommitRequest) + if g, w := commitRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the tag is reset after the statement. + var transactionTag string + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction_tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestTag_AutoCommit_BatchDml(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + conn, err := db.Conn(ctx) + defer func() { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + + _, _ = conn.ExecContext(ctx, "set transaction_tag = 'my_transaction_tag'") + _, _ = conn.ExecContext(ctx, "set statement_tag = 'tag_1'") + _, _ = conn.ExecContext(ctx, "start batch dml") + _, _ = conn.ExecContext(ctx, testutil.UpdateBarSetFoo) + _, _ = conn.ExecContext(ctx, testutil.UpdateBarSetFoo) + _, _ = conn.ExecContext(ctx, "run batch") + + requests := drainRequestsFromServer(server.TestSpanner) + // The ExecuteBatchDmlRequest and CommitRequest should have a transaction tag. + execRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{})) + if g, w := len(execRequests), 1; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + execRequest := execRequests[0].(*sppb.ExecuteBatchDmlRequest) + if g, w := execRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := execRequest.RequestOptions.RequestTag, "tag_1"; g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("number of commit request mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequest := commitRequests[0].(*sppb.CommitRequest) + if g, w := commitRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the tag is reset after the statement. + var transactionTag string + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction_tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestTag_ReadWriteTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + conn, err := db.Conn(ctx) + defer func() { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + + var transactionTag string + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } + _, _ = conn.ExecContext(ctx, "set transaction_tag = 'my_transaction_tag'") + tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + t.Fatal(err) + } + _, _ = tx.ExecContext(ctx, "set statement_tag = 'tag_1'") + rows, _ := tx.QueryContext(ctx, testutil.SelectFooFromBar) + for rows.Next() { + } + rows.Close() + + _, _ = tx.ExecContext(ctx, "set statement_tag = 'tag_2'") + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + + _, _ = tx.ExecContext(ctx, "set statement_tag = 'tag_3'") + _, _ = tx.ExecContext(ctx, "start batch dml") + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + _, _ = tx.ExecContext(ctx, "run batch") + _ = tx.Commit() + + requests := drainRequestsFromServer(server.TestSpanner) + // The ExecuteSqlRequest and CommitRequest should have a transaction tag. + execRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(execRequests), 2; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(execRequests); i++ { + execRequest := execRequests[i].(*sppb.ExecuteSqlRequest) + if g, w := execRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := execRequest.RequestOptions.RequestTag, fmt.Sprintf("tag_%d", (i%2)+1); g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + + batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{})) + if g, w := len(batchRequests), 1; g != w { + t.Fatalf("number of batch request mismatch\n Got: %v\nWant: %v", g, w) + } + batchRequest := batchRequests[0].(*sppb.ExecuteBatchDmlRequest) + if g, w := batchRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := batchRequest.RequestOptions.RequestTag, "tag_3"; g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("number of commit request mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequest := commitRequests[0].(*sppb.CommitRequest) + if g, w := commitRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + + // Verify that the tag is reset after the transaction. + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction_tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestTag_ReadWriteTransaction_Retry(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + conn, err := db.Conn(ctx) + defer func() { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + + for _, useArgs := range []bool{false, true} { + var transactionTag string + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } + _, _ = conn.ExecContext(ctx, "set transaction_tag = 'my_transaction_tag'") + tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + t.Fatal(err) + } + + var rows *sql.Rows + if useArgs { + rows, _ = tx.QueryContext(ctx, testutil.SelectFooFromBar, spanner.QueryOptions{RequestTag: "tag_1"}) + } else { + _, _ = tx.ExecContext(ctx, "set statement_tag='tag_1'") + rows, _ = tx.QueryContext(ctx, testutil.SelectFooFromBar) + } + for rows.Next() { + } + rows.Close() + + if useArgs { + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo, spanner.QueryOptions{RequestTag: "tag_2"}) + } else { + _, _ = tx.ExecContext(ctx, "set statement_tag='tag_2'") + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + } + + if useArgs { + _, _ = tx.ExecContext(ctx, "start batch dml", spanner.QueryOptions{RequestTag: "tag_3"}) + } else { + _, _ = tx.ExecContext(ctx, "set statement_tag = 'tag_3'") + _, _ = tx.ExecContext(ctx, "start batch dml") + } + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + _, _ = tx.ExecContext(ctx, "run batch") + + server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}, + }) + _ = tx.Commit() + + requests := drainRequestsFromServer(server.TestSpanner) + // The ExecuteSqlRequest and CommitRequest should have a transaction tag. + execRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(execRequests), 4; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(execRequests); i++ { + execRequest := execRequests[i].(*sppb.ExecuteSqlRequest) + // TODO: Remove when https://github.com/googleapis/google-cloud-go/pull/11443 + // has been merged. + if i < 2 { + if g, w := execRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + if g, w := execRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + if g, w := execRequest.RequestOptions.RequestTag, fmt.Sprintf("tag_%d", (i%2)+1); g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + + batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{})) + if g, w := len(batchRequests), 2; g != w { + t.Fatalf("number of batch request mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(batchRequests); i++ { + batchRequest := batchRequests[i].(*sppb.ExecuteBatchDmlRequest) + // TODO: Remove when https://github.com/googleapis/google-cloud-go/pull/11443 + // has been merged. + if i < 1 { + if g, w := batchRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + if g, w := batchRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + if g, w := batchRequest.RequestOptions.RequestTag, "tag_3"; g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 2; g != w { + t.Fatalf("number of commit request mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(commitRequests); i++ { + commitRequest := commitRequests[i].(*sppb.CommitRequest) + // TODO: Remove when https://github.com/googleapis/google-cloud-go/pull/11443 + // has been merged. + if i < 1 { + if g, w := commitRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + if g, w := commitRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + } + + // Verify that the tag is reset after the transaction. + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction_tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestTag_RunTransaction_Retry(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, server, teardown := setupTestDBConnection(t) + defer teardown() + + conn, err := db.Conn(ctx) + defer func() { + if err := conn.Close(); err != nil { + t.Fatal(err) + } + }() + if err != nil { + t.Fatalf("failed to get a connection: %v", err) + } + + for _, useArgs := range []bool{false, true} { + attempts := 0 + err = RunTransactionWithOptions(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + attempts++ + var rows *sql.Rows + if useArgs { + rows, _ = tx.QueryContext(ctx, testutil.SelectFooFromBar, spanner.QueryOptions{RequestTag: "tag_1"}) + } else { + _, _ = tx.ExecContext(ctx, "set statement_tag='tag_1'") + rows, _ = tx.QueryContext(ctx, testutil.SelectFooFromBar) + } + for rows.Next() { + } + rows.Close() + + if useArgs { + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo, spanner.QueryOptions{RequestTag: "tag_2"}) + } else { + _, _ = tx.ExecContext(ctx, "set statement_tag='tag_2'") + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + } + + if useArgs { + _, _ = tx.ExecContext(ctx, "start batch dml", spanner.QueryOptions{RequestTag: "tag_3"}) + } else { + _, _ = tx.ExecContext(ctx, "set statement_tag = 'tag_3'") + _, _ = tx.ExecContext(ctx, "start batch dml") + } + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + _, _ = tx.ExecContext(ctx, "run batch") + if attempts == 1 { + server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}, + }) + } + return nil + }, spanner.TransactionOptions{TransactionTag: "my_transaction_tag"}) + if err != nil { + t.Fatalf("failed to run transaction: %v", err) + } + + requests := drainRequestsFromServer(server.TestSpanner) + // The ExecuteSqlRequest and CommitRequest should have a transaction tag. + execRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(execRequests), 4; g != w { + t.Fatalf("number of execute requests mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(execRequests); i++ { + execRequest := execRequests[i].(*sppb.ExecuteSqlRequest) + // TODO: Remove when https://github.com/googleapis/google-cloud-go/pull/11443 + // has been merged. + if i < 2 { + if g, w := execRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + if g, w := execRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + if g, w := execRequest.RequestOptions.RequestTag, fmt.Sprintf("tag_%d", (i%2)+1); g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + + batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{})) + if g, w := len(batchRequests), 2; g != w { + t.Fatalf("number of batch request mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(batchRequests); i++ { + batchRequest := batchRequests[i].(*sppb.ExecuteBatchDmlRequest) + // TODO: Remove when https://github.com/googleapis/google-cloud-go/pull/11443 + // has been merged. + if i < 1 { + if g, w := batchRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + if g, w := batchRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + if g, w := batchRequest.RequestOptions.RequestTag, "tag_3"; g != w { + t.Fatalf("statement tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + + commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) + if g, w := len(commitRequests), 2; g != w { + t.Fatalf("number of commit request mismatch\n Got: %v\nWant: %v", g, w) + } + for i := 0; i < len(commitRequests); i++ { + commitRequest := commitRequests[i].(*sppb.CommitRequest) + // TODO: Remove when https://github.com/googleapis/google-cloud-go/pull/11443 + // has been merged. + if i < 1 { + if g, w := commitRequest.RequestOptions.TransactionTag, "my_transaction_tag"; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + if g, w := commitRequest.RequestOptions.TransactionTag, ""; g != w { + t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w) + } + } + } + + // Verify that the transaction tag is reset after the transaction. + var transactionTag string + if err := conn.QueryRowContext(ctx, "SHOW VARIABLE TRANSACTION_TAG").Scan(&transactionTag); err != nil { + t.Fatalf("failed to get transaction_tag: %v", err) + } + if g, w := transactionTag, ""; g != w { + t.Fatalf("transaction_tag mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + func TestMaxIdleConnectionsNonZero(t *testing.T) { t.Parallel() diff --git a/examples/run-transaction/main.go b/examples/run-transaction/main.go index c9c2a327..c423e2a9 100644 --- a/examples/run-transaction/main.go +++ b/examples/run-transaction/main.go @@ -20,6 +20,7 @@ import ( "fmt" "sync" + "cloud.google.com/go/spanner" spannerdriver "github.com/googleapis/go-sql-spanner" "github.com/googleapis/go-sql-spanner/examples" ) @@ -59,7 +60,7 @@ func runTransaction(projectId, instanceId, databaseId string) error { // will be aborted and retried by Spanner multiple times. The end result // will still be that all transactions succeed and the name contains all // indexes in an undefined order. - errors[index] = spannerdriver.RunTransaction(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + errors[index] = spannerdriver.RunTransactionWithOptions(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { // Query the singer in the transaction. This will take a lock on the row and guarantee that // the value that we read is still the same when the transaction is committed. If not, Spanner // will abort the transaction, and the transaction will be retried. @@ -82,7 +83,7 @@ func runTransaction(projectId, instanceId, databaseId string) error { return fmt.Errorf("unexpected affected row count: %d", affected) } return nil - }) + }, spanner.TransactionOptions{TransactionTag: "sample_transaction"}) }() } wg.Wait() diff --git a/examples/tags/main.go b/examples/tags/main.go new file mode 100644 index 00000000..e14744fa --- /dev/null +++ b/examples/tags/main.go @@ -0,0 +1,175 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "database/sql" + "fmt" + + "cloud.google.com/go/spanner" + spannerdriver "github.com/googleapis/go-sql-spanner" + "github.com/googleapis/go-sql-spanner/examples" +) + +var createTableStatement = "CREATE TABLE Singers (SingerId INT64, Name STRING(MAX)) PRIMARY KEY (SingerId)" + +// Example for using transaction tags and statement tags through SQL statements. +// +// Tags can also be set programmatically using spannerdriver.RunTransactionWithOptions +// and the query argument spannerdriver.StatementTag. +// +// Execute the sample with the command `go run main.go` from this directory. +func tagsWithSqlStatements(projectId, instanceId, databaseId string) error { + fmt.Println("Running sample for setting tags with SQL statements") + + ctx := context.Background() + db, err := sql.Open("spanner", fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectId, instanceId, databaseId)) + if err != nil { + return err + } + defer db.Close() + + // Obtain a connection for the database in order to ensure that we + // set the transaction tag on the same connection as the connection + // that will execute the transaction. + conn, err := db.Conn(ctx) + if err != nil { + return err + } + + // Set a transaction tag on the connection and start a transaction. + // This transaction tag will be applied to the next transaction that + // is executed by this connection. The transaction tag is automatically + // included with all statements of that transaction. + if _, err := conn.ExecContext(ctx, "set transaction_tag = 'my_transaction_tag'"); err != nil { + return err + } + fmt.Println("Executing transaction with transaction tag 'my_transaction_tag'") + tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return err + } + + // Set a statement tag and insert a new record using the transaction that we just started. + if _, err := tx.ExecContext(ctx, "set statement_tag = 'insert_singer'"); err != nil { + _ = tx.Rollback() + return err + } + fmt.Println("Executing statement with tag 'insert_singer'") + _, err = tx.ExecContext(ctx, "INSERT INTO Singers (SingerId, Name) VALUES (@id, @name)", 123, "Bruce Allison") + if err != nil { + _ = tx.Rollback() + return err + } + + // Set another statement tag and execute a query. + if _, err := tx.ExecContext(ctx, "set statement_tag = 'select_singer'"); err != nil { + _ = tx.Rollback() + return err + } + fmt.Println("Executing statement with tag 'select_singer'") + rows, err := tx.QueryContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = ?", 123) + if err != nil { + _ = tx.Rollback() + return err + } + var ( + id int64 + name string + ) + for rows.Next() { + if err := rows.Scan(&id, &name); err != nil { + _ = tx.Rollback() + return err + } + fmt.Printf("Found singer: %v %v\n", id, name) + } + if err := rows.Err(); err != nil { + _ = tx.Rollback() + return err + } + _ = rows.Close() + if err := tx.Commit(); err != nil { + return err + } + + fmt.Println("Finished transaction with tag 'my_transaction_tag'") + + return nil +} + +// tagsProgrammatically shows how to set transaction tags and statement tags +// programmatically. +// +// Note: It is not recommended to mix using SQL statements and passing in +// tags or other QueryOptions programmatically. +func tagsProgrammatically(projectId, instanceId, databaseId string) error { + fmt.Println("Running sample for setting tags programmatically") + + ctx := context.Background() + db, err := sql.Open("spanner", fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectId, instanceId, databaseId)) + if err != nil { + return err + } + defer db.Close() + + // Use RunTransactionWithOptions to set a transaction tag programmatically. + fmt.Println("Executing transaction with transaction tag 'my_transaction_tag'") + if err := spannerdriver.RunTransactionWithOptions(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + fmt.Println("Executing statement with tag 'insert_singer'") + // Pass in a value of spanner.QueryOptions to specify the options that should be used for a DML statement. + _, err = tx.ExecContext(ctx, "INSERT INTO Singers (SingerId, Name) VALUES (@id, @name)", spanner.QueryOptions{RequestTag: "insert_singer"}, 123, "Bruce Allison") + if err != nil { + return err + } + + fmt.Println("Executing statement with tag 'select_singer'") + // Pass in a value of spanner.QueryOptions to specify the options that should be used for a query. + rows, err := tx.QueryContext(ctx, "SELECT SingerId, Name FROM Singers WHERE SingerId = ?", spanner.QueryOptions{RequestTag: "select_singer"}, 123) + if err != nil { + return err + } + var ( + id int64 + name string + ) + for rows.Next() { + if err := rows.Scan(&id, &name); err != nil { + _ = tx.Rollback() + return err + } + fmt.Printf("Found singer: %v %v\n", id, name) + } + if err := rows.Err(); err != nil { + return err + } + _ = rows.Close() + + return nil + }, spanner.TransactionOptions{TransactionTag: "my_transaction_tag"}); err != nil { + return err + } + + fmt.Println("Finished transaction with tag 'my_transaction_tag'") + + return nil +} + +func main() { + examples.RunSampleOnEmulator(tagsWithSqlStatements, createTableStatement) + fmt.Println() + examples.RunSampleOnEmulator(tagsProgrammatically, createTableStatement) +} diff --git a/stmt.go b/stmt.go index f4e7881c..7d4f8429 100644 --- a/stmt.go +++ b/stmt.go @@ -27,6 +27,7 @@ type stmt struct { conn *conn numArgs int query string + options *spanner.QueryOptions } func (s *stmt) Close() error { @@ -55,15 +56,33 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv return nil, err } + var options spanner.QueryOptions + if s.options != nil { + options = *s.options + } else { + options = s.conn.createQueryOptions() + } var it rowIterator if s.conn.tx != nil { - it = s.conn.tx.Query(ctx, ss) + it = s.conn.tx.Query(ctx, ss, options) } else { - it = &readOnlyRowIterator{s.conn.client.Single().WithTimestampBound(s.conn.readOnlyStaleness).Query(ctx, ss)} + it = &readOnlyRowIterator{s.conn.client.Single().WithTimestampBound(s.conn.readOnlyStaleness).QueryWithOptions(ctx, ss, options)} } return &rows{it: it}, nil } +func (s *stmt) CheckNamedValue(value *driver.NamedValue) error { + if value == nil { + return nil + } + options, ok := value.Value.(spanner.QueryOptions) + if ok { + s.options = &options + return driver.ErrRemoveArgument + } + return nil +} + func prepareSpannerStmt(q string, args []driver.NamedValue) (spanner.Statement, error) { q, names, err := parseParameters(q) if err != nil { diff --git a/transaction.go b/transaction.go index 465024a5..31951fa7 100644 --- a/transaction.go +++ b/transaction.go @@ -33,10 +33,10 @@ type contextTransaction interface { Commit() error Rollback() error resetForRetry(ctx context.Context) error - Query(ctx context.Context, stmt spanner.Statement) rowIterator - ExecContext(ctx context.Context, stmt spanner.Statement) (int64, error) + Query(ctx context.Context, stmt spanner.Statement, options spanner.QueryOptions) rowIterator + ExecContext(ctx context.Context, stmt spanner.Statement, options spanner.QueryOptions) (int64, error) - StartBatchDML() (driver.Result, error) + StartBatchDML(options spanner.QueryOptions) (driver.Result, error) RunBatch(ctx context.Context) (driver.Result, error) AbortBatch() (driver.Result, error) @@ -95,15 +95,15 @@ func (tx *readOnlyTransaction) resetForRetry(ctx context.Context) error { return nil } -func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement) rowIterator { - return &readOnlyRowIterator{tx.roTx.Query(ctx, stmt)} +func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement, options spanner.QueryOptions) rowIterator { + return &readOnlyRowIterator{tx.roTx.QueryWithOptions(ctx, stmt, options)} } -func (tx *readOnlyTransaction) ExecContext(_ context.Context, stmt spanner.Statement) (int64, error) { +func (tx *readOnlyTransaction) ExecContext(_ context.Context, _ spanner.Statement, _ spanner.QueryOptions) (int64, error) { return 0, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "read-only transactions cannot write")) } -func (tx *readOnlyTransaction) StartBatchDML() (driver.Result, error) { +func (tx *readOnlyTransaction) StartBatchDML(_ spanner.QueryOptions) (driver.Result, error) { return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "read-only transactions cannot write")) } @@ -175,7 +175,8 @@ type retriableStatement interface { // retriableUpdate implements retriableStatement for update statements. type retriableUpdate struct { // stmt is the statement that was executed on Spanner. - stmt spanner.Statement + stmt spanner.Statement + options spanner.QueryOptions // c is the record count that was returned by Spanner. c int64 // err is the error that was returned by Spanner. @@ -186,7 +187,7 @@ type retriableUpdate struct { // of the statement during the retry is equal to the result during the initial // attempt. func (ru *retriableUpdate) retry(ctx context.Context, tx *spanner.ReadWriteStmtBasedTransaction) error { - c, err := tx.Update(ctx, ru.stmt) + c, err := tx.UpdateWithOptions(ctx, ru.stmt, ru.options) if err != nil && spanner.ErrCode(err) == codes.Aborted { return err } @@ -203,6 +204,7 @@ func (ru *retriableUpdate) retry(ctx context.Context, tx *spanner.ReadWriteStmtB type retriableBatchUpdate struct { // statements are the statement that were executed on Spanner. statements []spanner.Statement + options spanner.QueryOptions // c is the record counts that were returned by Spanner. c []int64 // err is the error that was returned by Spanner. @@ -213,7 +215,7 @@ type retriableBatchUpdate struct { // of the statement during the retry is equal to the result during the initial // attempt. func (ru *retriableBatchUpdate) retry(ctx context.Context, tx *spanner.ReadWriteStmtBasedTransaction) error { - c, err := tx.BatchUpdate(ctx, ru.statements) + c, err := tx.BatchUpdateWithOptions(ctx, ru.statements, ru.options) if err != nil && spanner.ErrCode(err) == codes.Aborted { return err } @@ -313,21 +315,22 @@ func (tx *readWriteTransaction) resetForRetry(ctx context.Context) error { // Query executes a query using the read/write transaction and returns a // rowIterator that will automatically retry the read/write transaction if the // transaction is aborted during the query or while iterating the returned rows. -func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statement) rowIterator { +func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statement, options spanner.QueryOptions) rowIterator { // If internal retries have been disabled, we don't need to keep track of a // running checksum for all results that we have seen. if !tx.retryAborts { - return &readOnlyRowIterator{tx.rwTx.Query(ctx, stmt)} + return &readOnlyRowIterator{tx.rwTx.QueryWithOptions(ctx, stmt, options)} } // If retries are enabled, we need to use a row iterator that will keep // track of a running checksum of all the results that we see. buffer := &bytes.Buffer{} it := &checksumRowIterator{ - RowIterator: tx.rwTx.Query(ctx, stmt), + RowIterator: tx.rwTx.QueryWithOptions(ctx, stmt, options), ctx: ctx, tx: tx, stmt: stmt, + options: options, buffer: buffer, enc: gob.NewEncoder(buffer), } @@ -335,33 +338,34 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen return it } -func (tx *readWriteTransaction) ExecContext(ctx context.Context, stmt spanner.Statement) (res int64, err error) { +func (tx *readWriteTransaction) ExecContext(ctx context.Context, stmt spanner.Statement, options spanner.QueryOptions) (res int64, err error) { if tx.batch != nil { tx.batch.statements = append(tx.batch.statements, stmt) return 0, nil } if !tx.retryAborts { - return tx.rwTx.Update(ctx, stmt) + return tx.rwTx.UpdateWithOptions(ctx, stmt, options) } err = tx.runWithRetry(ctx, func(ctx context.Context) error { - res, err = tx.rwTx.Update(ctx, stmt) + res, err = tx.rwTx.UpdateWithOptions(ctx, stmt, options) return err }) tx.statements = append(tx.statements, &retriableUpdate{ - stmt: stmt, - c: res, - err: err, + stmt: stmt, + options: options, + c: res, + err: err, }) return res, err } -func (tx *readWriteTransaction) StartBatchDML() (driver.Result, error) { +func (tx *readWriteTransaction) StartBatchDML(options spanner.QueryOptions) (driver.Result, error) { if tx.batch != nil { return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "This transaction already has an active batch.")) } - tx.batch = &batch{tp: dml} + tx.batch = &batch{tp: dml, options: options} return driver.ResultNoRows, nil } @@ -386,21 +390,23 @@ func (tx *readWriteTransaction) AbortBatch() (driver.Result, error) { func (tx *readWriteTransaction) runDmlBatch(ctx context.Context) (driver.Result, error) { statements := tx.batch.statements + options := tx.batch.options tx.batch = nil if !tx.retryAborts { - affected, err := tx.rwTx.BatchUpdate(ctx, statements) + affected, err := tx.rwTx.BatchUpdateWithOptions(ctx, statements, options) return &result{rowsAffected: sum(affected)}, err } var affected []int64 var err error err = tx.runWithRetry(ctx, func(ctx context.Context) error { - affected, err = tx.rwTx.BatchUpdate(ctx, statements) + affected, err = tx.rwTx.BatchUpdateWithOptions(ctx, statements, options) return err }) tx.statements = append(tx.statements, &retriableBatchUpdate{ statements: statements, + options: options, c: affected, err: err, })