From 7f90ff1df6ced91daeb8e6b6398648323f32d26a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Sun, 5 Oct 2025 13:14:43 +0200 Subject: [PATCH 1/4] feat: parse SET TRANSACTION statements Parse SET TRANSACTION statements and translate these to SET LOCAL statements. SET TRANSACTION may only be executed in a transaction block, and can only be used for a specific, limited set of connection properties. The syntax is specified by the SQL standard and PostgreSQL. See also https://www.postgresql.org/docs/current/sql-set-transaction.html This change only adds partial support. The following features will be added in future changes: 1. SET TRANSACTION READ {WRITE | ONLY} is not picked up by the driver, as the type of transaction is set directly when BeginTx is called. A refactor of this transaction handling is needed to be able to pick up SET TRANSACTION READ ONLY / SET TRANSACTION READ WRITE statements that are executed after BeginTx has been called. 2. PostgreSQL allows multiple transaction modes to be set in a single SET TRANSACTION statement. E.g. the following is allowed: SET TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE The current implementation only supports one transaction mode per SET statement. --- conn.go | 11 +- connection_properties.go | 18 ++++ connectionstate/connection_state.go | 5 +- parser/simple_parser.go | 14 ++- parser/statement_parser_test.go | 34 ++++++ parser/statements.go | 115 +++++++++++++++++++- parser/statements_test.go | 162 +++++++++++++++++++++++----- statements.go | 8 +- transaction_test.go | 107 ++++++++++++++++++ 9 files changed, 436 insertions(+), 38 deletions(-) create mode 100644 transaction_test.go diff --git a/conn.go b/conn.go index 31dc2c41..5b2ca244 100644 --- a/conn.go +++ b/conn.go @@ -332,13 +332,18 @@ func (c *conn) showConnectionVariable(identifier parser.Identifier) (any, bool, return c.state.GetValue(extension, name) } -func (c *conn) setConnectionVariable(identifier parser.Identifier, value string, local bool) error { +func (c *conn) setConnectionVariable(identifier parser.Identifier, value string, local bool, transaction bool) error { + if transaction && !local { + // When transaction == true, then local must also be true. + // We should never hit this condition, as this is an indication of a bug in the driver code. + return status.Errorf(codes.FailedPrecondition, "transaction properties must be set as a local value") + } extension, name, err := toExtensionAndName(identifier) if err != nil { return err } if local { - return c.state.SetLocalValue(extension, name, value) + return c.state.SetLocalValue(extension, name, value, transaction) } return c.state.SetValue(extension, name, value, connectionstate.ContextUser) } @@ -1144,6 +1149,8 @@ func (c *conn) BeginTx(ctx context.Context, driverOpts driver.TxOptions) (driver } }() + // TODO: Delay the actual determination of the transaction type until the first query. + // This is required in order to support SET TRANSACTION READ {ONLY | WRITE} readOnlyTxOpts := c.getReadOnlyTransactionOptions() batchReadOnlyTxOpts := c.getBatchReadOnlyTransactionOptions() if c.inTransaction() { diff --git a/connection_properties.go b/connection_properties.go index be6716b7..bedf5561 100644 --- a/connection_properties.go +++ b/connection_properties.go @@ -228,6 +228,24 @@ var propertyDecodeToNativeArrays = createConnectionProperty( // Transaction connection properties. // ------------------------------------------------------------------------------------------------ +var propertyTransactionReadOnly = createConnectionProperty( + "transaction_read_only", + "transaction_read_only is the default read-only mode for transactions on this connection.", + false, + false, + nil, + connectionstate.ContextUser, + connectionstate.ConvertBool, +) +var propertyTransactionDeferrable = createConnectionProperty( + "transaction_deferrable", + "transaction_deferrable is a no-op on Spanner. It is defined in this driver for compatibility with PostgreSQL.", + false, + false, + nil, + connectionstate.ContextUser, + connectionstate.ConvertBool, +) var propertyExcludeTxnFromChangeStreams = createConnectionProperty( "exclude_txn_from_change_streams", "exclude_txn_from_change_streams determines whether transactions on this connection should be excluded from "+ diff --git a/connectionstate/connection_state.go b/connectionstate/connection_state.go index 6fd7131d..77417eea 100644 --- a/connectionstate/connection_state.go +++ b/connectionstate/connection_state.go @@ -131,7 +131,10 @@ func (cs *ConnectionState) SetValue(extension, name, value string, context Conte return cs.setValue(extension, name, value, context, false) } -func (cs *ConnectionState) SetLocalValue(extension, name, value string) error { +func (cs *ConnectionState) SetLocalValue(extension, name, value string, isSetTransaction bool) error { + if isSetTransaction && !cs.inTransaction { + return status.Error(codes.FailedPrecondition, "SET TRANSACTION can only be used in transaction blocks") + } return cs.setValue(extension, name, value, ContextUser, true) } diff --git a/parser/simple_parser.go b/parser/simple_parser.go index 9a2da379..e8fef032 100644 --- a/parser/simple_parser.go +++ b/parser/simple_parser.go @@ -290,6 +290,16 @@ func (p *simpleParser) eatKeywords(keywords []string) bool { return true } +// peekKeyword checks if the next keyword is the given keyword. +// The position of the parser is not updated. +func (p *simpleParser) peekKeyword(keyword string) bool { + pos := p.pos + defer func() { + p.pos = pos + }() + return p.eatKeyword(keyword) +} + // eatKeyword eats the given keyword at the current position of the parser if it exists // and returns true if the keyword was found. Otherwise, it returns false. func (p *simpleParser) eatKeyword(keyword string) bool { @@ -323,8 +333,8 @@ func (p *simpleParser) readKeyword() string { if isSpace(p.sql[p.pos]) { break } - // Only upper/lower-case letters are allowed in keywords. - if !((p.sql[p.pos] >= 'A' && p.sql[p.pos] <= 'Z') || (p.sql[p.pos] >= 'a' && p.sql[p.pos] <= 'z')) { + // Only upper/lower-case letters and underscores are allowed in keywords. + if !((p.sql[p.pos] >= 'A' && p.sql[p.pos] <= 'Z') || (p.sql[p.pos] >= 'a' && p.sql[p.pos] <= 'z')) && p.sql[p.pos] != '_' { break } } diff --git a/parser/statement_parser_test.go b/parser/statement_parser_test.go index ed50be2b..2ab64399 100644 --- a/parser/statement_parser_test.go +++ b/parser/statement_parser_test.go @@ -2135,6 +2135,10 @@ func TestReadKeyword(t *testing.T) { input: "Select from my_table", want: "Select", }, + { + input: "statement_tag", + want: "statement_tag", + }, } statementParser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000) if err != nil { @@ -2404,6 +2408,36 @@ func TestCachedParamsAreImmutable(t *testing.T) { } } +func TestPeekKeyword(t *testing.T) { + t.Parallel() + + parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000) + if err != nil { + t.Fatal(err) + } + sp := &simpleParser{sql: []byte("select * from foo"), statementParser: parser} + if !sp.peekKeyword("select") { + t.Fatal("peekKeyword should have returned true") + } + if g, w := sp.pos, 0; g != w { + t.Fatalf("position mismatch\n Got: %v\nWant: %v", g, w) + } + + if !sp.eatKeyword("select") { + t.Fatal("eatKeyword should have returned true") + } + if !sp.eatToken('*') { + t.Fatal("eatToken should have returned true") + } + pos := sp.pos + if !sp.peekKeyword("from") { + t.Fatal("peekKeyword should have returned true") + } + if g, w := sp.pos, pos; g != w { + t.Fatalf("position mismatch\n Got: %v\nWant: %v", g, w) + } +} + func TestEatKeyword(t *testing.T) { t.Parallel() diff --git a/parser/statements.go b/parser/statements.go index 2dc1ce93..abfc765d 100644 --- a/parser/statements.go +++ b/parser/statements.go @@ -15,6 +15,8 @@ package parser import ( + "fmt" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -142,11 +144,19 @@ func (s *ParsedShowStatement) parse(parser *StatementParser, query string) error // ParsedSetStatement is a statement of the form // SET [SESSION | LOCAL] [my_extension.]my_property {=|to} +// +// It also covers statements of the form SET TRANSACTION. This is a +// synonym for SET LOCAL, but is only supported for a specific set of +// properties, and may only be executed before a transaction has been +// activated. Examples include: +// SET TRANSACTION READ ONLY +// SET TRANSACTION ISOLATION LEVEL [SERIALIZABLE | REPEATABLE READ] type ParsedSetStatement struct { - query string - Identifier Identifier - Literal Literal - IsLocal bool + query string + Identifier Identifier + Literal Literal + IsLocal bool + IsTransaction bool } func (s *ParsedSetStatement) Name() string { @@ -165,10 +175,17 @@ func (s *ParsedSetStatement) parse(parser *StatementParser, query string) error return status.Errorf(codes.InvalidArgument, "syntax error: expected SET") } isLocal := sp.eatKeyword("LOCAL") - if !isLocal && parser.Dialect == databasepb.DatabaseDialect_POSTGRESQL { + isTransaction := false + if !isLocal { + isTransaction = sp.eatKeyword("TRANSACTION") + } + if !isLocal && !isTransaction && parser.Dialect == databasepb.DatabaseDialect_POSTGRESQL { // Just eat and ignore the SESSION keyword if it exists, as SESSION is the default. _ = sp.eatKeyword("SESSION") } + if isTransaction { + return s.parseSetTransaction(sp, query) + } identifier, err := sp.eatIdentifier() if err != nil { return err @@ -197,6 +214,93 @@ func (s *ParsedSetStatement) parse(parser *StatementParser, query string) error return nil } +func (s *ParsedSetStatement) parseSetTransaction(sp *simpleParser, query string) error { + if !sp.hasMoreTokens() { + return status.Errorf(codes.InvalidArgument, "syntax error: missing TRANSACTION OPTION, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") + } + // TODO: Support multiple transaction mode settings in one statement: + // SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE + if sp.peekKeyword("ISOLATION") { + return s.parseSetTransactionIsolationLevel(sp, query) + } + if sp.peekKeyword("READ") { + return s.parseSetTransactionMode(sp, query) + } + if sp.statementParser.Dialect == databasepb.DatabaseDialect_POSTGRESQL { + // https://www.postgresql.org/docs/current/sql-set-transaction.html + if sp.peekKeyword("DEFERRABLE") || sp.peekKeyword("NOT") { + return s.parseSetTransactionDeferrable(sp, query) + } + } + return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") +} + +func (s *ParsedSetStatement) parseSetTransactionIsolationLevel(sp *simpleParser, query string) error { + if !sp.eatKeywords([]string{"ISOLATION", "LEVEL"}) { + return status.Errorf(codes.InvalidArgument, "syntax error: expected ISOLATION LEVEL") + } + var value Literal + if sp.eatKeyword("SERIALIZABLE") { + value = Literal{Value: "serializable"} + } else if sp.eatKeywords([]string{"REPEATABLE", "READ"}) { + value = Literal{Value: "repeatable_read"} + } else { + return status.Errorf(codes.InvalidArgument, "syntax error: expected SERIALIZABLE OR REPETABLE READ") + } + if sp.hasMoreTokens() { + return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) + } + + s.query = query + s.Identifier = Identifier{Parts: []string{"isolation_level"}} + s.Literal = value + s.IsLocal = true + s.IsTransaction = true + return nil +} + +func (s *ParsedSetStatement) parseSetTransactionMode(sp *simpleParser, query string) error { + readOnly := false + if sp.eatKeywords([]string{"READ", "ONLY"}) { + readOnly = true + } else if sp.eatKeywords([]string{"READ", "WRITE"}) { + readOnly = false + } else { + return status.Errorf(codes.InvalidArgument, "syntax error: expected READ ONLY or READ WRITE") + } + if sp.hasMoreTokens() { + return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) + } + + s.query = query + s.Identifier = Identifier{Parts: []string{"transaction_read_only"}} + s.Literal = Literal{Value: fmt.Sprintf("%v", readOnly)} + s.IsLocal = true + s.IsTransaction = true + return nil +} + +func (s *ParsedSetStatement) parseSetTransactionDeferrable(sp *simpleParser, query string) error { + deferrable := false + if sp.eatKeywords([]string{"NOT", "DEFERRABLE"}) { + deferrable = false + } else if sp.eatKeyword("DEFERRABLE") { + deferrable = true + } else { + return status.Errorf(codes.InvalidArgument, "syntax error: expected [NOT] DEFERRABLE") + } + if sp.hasMoreTokens() { + return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) + } + + s.query = query + s.Identifier = Identifier{Parts: []string{"transaction_deferrable"}} + s.Literal = Literal{Value: fmt.Sprintf("%v", deferrable)} + s.IsLocal = true + s.IsTransaction = true + return nil +} + // ParsedResetStatement is a statement of the form // RESET [my_extension.]my_property type ParsedResetStatement struct { @@ -404,6 +508,7 @@ func (s *ParsedBeginStatement) parse(parser *StatementParser, query string) erro // Parse a statement of the form // GoogleSQL: BEGIN [TRANSACTION] // PostgreSQL: {START | BEGIN} [{TRANSACTION | WORK}] (https://www.postgresql.org/docs/current/sql-begin.html) + // TODO: Support transaction modes in the BEGIN / START statement. sp := &simpleParser{sql: []byte(query), statementParser: parser} if sp.statementParser.Dialect == databasepb.DatabaseDialect_POSTGRESQL { if !sp.eatKeyword("START") && !sp.eatKeyword("BEGIN") { diff --git a/parser/statements_test.go b/parser/statements_test.go index a7bb1d1c..b0718d57 100644 --- a/parser/statements_test.go +++ b/parser/statements_test.go @@ -15,6 +15,7 @@ package parser import ( + "fmt" "reflect" "strings" "testing" @@ -128,6 +129,7 @@ func TestParseSetStatement(t *testing.T) { type test struct { input string want ParsedSetStatement + onlyPg bool wantErr bool } tests := []test{ @@ -180,6 +182,116 @@ func TestParseSetStatement(t *testing.T) { Literal: Literal{Value: "value"}, }, }, + { + input: "set transaction isolation level serializable", + want: ParsedSetStatement{ + query: "set transaction isolation level serializable", + Identifier: Identifier{Parts: []string{"isolation_level"}}, + Literal: Literal{Value: "serializable"}, + IsLocal: true, + IsTransaction: true, + }, + }, + { + input: "set transaction isolation serializable", + wantErr: true, + }, + { + input: "set transaction isolation level serializable foo", + wantErr: true, + }, + { + input: "set isolation level serializable", + wantErr: true, + }, + { + input: "set transaction isolation level serialisable", + wantErr: true, + }, + { + input: "set transaction isolation level repeatable read", + want: ParsedSetStatement{ + query: "set transaction isolation level repeatable read", + Identifier: Identifier{Parts: []string{"isolation_level"}}, + Literal: Literal{Value: "repeatable_read"}, + IsLocal: true, + IsTransaction: true, + }, + }, + { + input: "set transaction isolation level repeatable", + wantErr: true, + }, + { + input: "set transaction isolation level read", + wantErr: true, + }, + { + input: "set transaction isolation level repeatable read serializable", + wantErr: true, + }, + { + input: "set transaction isolation level serializable repeatable read", + wantErr: true, + }, + { + input: "set transaction read write", + want: ParsedSetStatement{ + query: "set transaction read write", + Identifier: Identifier{Parts: []string{"transaction_read_only"}}, + Literal: Literal{Value: "false"}, + IsLocal: true, + IsTransaction: true, + }, + }, + { + input: "set transaction read only", + want: ParsedSetStatement{ + query: "set transaction read only", + Identifier: Identifier{Parts: []string{"transaction_read_only"}}, + Literal: Literal{Value: "true"}, + IsLocal: true, + IsTransaction: true, + }, + }, + { + input: "set transaction read", + wantErr: true, + }, + { + input: "set transaction write", + wantErr: true, + }, + { + input: "set transaction read only write", + wantErr: true, + }, + { + input: "set transaction write only", + wantErr: true, + }, + { + onlyPg: true, + input: "set transaction deferrable", + want: ParsedSetStatement{ + query: "set transaction deferrable", + Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, + Literal: Literal{Value: "true"}, + IsLocal: true, + IsTransaction: true, + }, + }, + { + onlyPg: true, + input: "set transaction not deferrable", + want: ParsedSetStatement{ + query: "set transaction not deferrable", + Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, + Literal: Literal{Value: "false"}, + IsLocal: true, + IsTransaction: true, + }, + }, { input: "set my_property =", wantErr: true, @@ -205,31 +317,33 @@ func TestParseSetStatement(t *testing.T) { wantErr: true, }, } - parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000) - if err != nil { - t.Fatal(err) - } - keyword := "SET" - for _, test := range tests { - t.Run(test.input, func(t *testing.T) { - stmt, err := parseStatement(parser, keyword, test.input) - if test.wantErr { - if err == nil { - t.Fatalf("parseStatement(%q) should have failed", test.input) - } - } else { - if err != nil { - t.Fatal(err) - } - showStmt, ok := stmt.(*ParsedSetStatement) - if !ok { - t.Fatalf("parseStatement(%q) should have returned a *parsedSetStatement", test.input) + for _, dialect := range []databasepb.DatabaseDialect{databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, databasepb.DatabaseDialect_POSTGRESQL} { + parser, err := NewStatementParser(dialect, 1000) + if err != nil { + t.Fatal(err) + } + keyword := "SET" + for _, test := range tests { + t.Run(fmt.Sprintf("%s %s", dialect, test.input), func(t *testing.T) { + stmt, err := parseStatement(parser, keyword, test.input) + if test.wantErr || (test.onlyPg && dialect == databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL) { + if err == nil { + t.Fatalf("parseStatement(%q) should have failed", test.input) + } + } else { + if err != nil { + t.Fatal(err) + } + showStmt, ok := stmt.(*ParsedSetStatement) + if !ok { + t.Fatalf("parseStatement(%q) should have returned a *parsedSetStatement", test.input) + } + if !reflect.DeepEqual(*showStmt, test.want) { + t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, test.want) + } } - if !reflect.DeepEqual(*showStmt, test.want) { - t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, test.want) - } - } - }) + }) + } } } diff --git a/statements.go b/statements.go index f7232a7b..f8e06d02 100644 --- a/statements.go +++ b/statements.go @@ -112,14 +112,14 @@ type executableSetStatement struct { } func (s *executableSetStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) { - if err := c.setConnectionVariable(s.stmt.Identifier, s.stmt.Literal.Value, s.stmt.IsLocal); err != nil { + if err := c.setConnectionVariable(s.stmt.Identifier, s.stmt.Literal.Value, s.stmt.IsLocal, s.stmt.IsTransaction); err != nil { return nil, err } return driver.ResultNoRows, nil } func (s *executableSetStatement) queryContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Rows, error) { - if err := c.setConnectionVariable(s.stmt.Identifier, s.stmt.Literal.Value, s.stmt.IsLocal); err != nil { + if err := c.setConnectionVariable(s.stmt.Identifier, s.stmt.Literal.Value, s.stmt.IsLocal, s.stmt.IsTransaction); err != nil { return nil, err } return createEmptyRows(opts), nil @@ -131,14 +131,14 @@ type executableResetStatement struct { } func (s *executableResetStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) { - if err := c.setConnectionVariable(s.stmt.Identifier, "default", false); err != nil { + if err := c.setConnectionVariable(s.stmt.Identifier, "default", false, false); err != nil { return nil, err } return driver.ResultNoRows, nil } func (s *executableResetStatement) queryContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Rows, error) { - if err := c.setConnectionVariable(s.stmt.Identifier, "default", false); err != nil { + if err := c.setConnectionVariable(s.stmt.Identifier, "default", false, false); err != nil { return nil, err } return createEmptyRows(opts), nil diff --git a/transaction_test.go b/transaction_test.go new file mode 100644 index 00000000..d7c49ad2 --- /dev/null +++ b/transaction_test.go @@ -0,0 +1,107 @@ +package spannerdriver + +import ( + "context" + "database/sql" + "reflect" + "testing" + + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/testutil" +) + +func TestSetTransactionIsolationLevel(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + tx, _ := db.BeginTx(ctx, &sql.TxOptions{}) + if _, err := tx.ExecContext(ctx, "set transaction isolation level repeatable read"); err != nil { + t.Fatal(err) + } + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + _ = tx.Commit() + + requests := server.TestSpanner.DrainRequestsFromServer() + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil { + t.Fatal("missing begin transaction on ExecuteSqlRequest") + } + if g, w := request.GetTransaction().GetBegin().GetIsolationLevel(), spannerpb.TransactionOptions_REPEATABLE_READ; g != w { + t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestSetTransactionReadOnly(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + tx, _ := db.BeginTx(ctx, &sql.TxOptions{}) + if _, err := tx.ExecContext(ctx, "set transaction read only"); err != nil { + t.Fatal(err) + } + row := tx.QueryRowContext(ctx, testutil.SelectFooFromBar, ExecOptions{DirectExecuteQuery: true}) + if err := row.Err(); err != nil { + t.Fatal(err) + } + _ = tx.Commit() + + requests := server.TestSpanner.DrainRequestsFromServer() + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil { + t.Fatal("missing begin transaction on ExecuteSqlRequest") + } + // TODO: Enable once transaction_read_only is picked up by the driver. + //readOnly := request.GetTransaction().GetBegin().GetReadOnly() + //if readOnly == nil { + // t.Fatal("missing readOnly on ExecuteSqlRequest") + //} +} + +func TestSetTransactionDeferrable(t *testing.T) { + t.Parallel() + + // SET TRANSACTION [NOT] DEFERRABLE is only supported for PostgreSQL-dialect databases. + db, _, teardown := setupTestDBConnectionWithParamsAndDialect(t, "", databasepb.DatabaseDialect_POSTGRESQL) + defer teardown() + ctx := context.Background() + + tx, _ := db.BeginTx(ctx, &sql.TxOptions{}) + if _, err := tx.ExecContext(ctx, "set transaction deferrable"); err != nil { + t.Fatal(err) + } + row := tx.QueryRowContext(ctx, testutil.SelectFooFromBar, ExecOptions{DirectExecuteQuery: true}) + if err := row.Err(); err != nil { + t.Fatal(err) + } + + // transaction_deferrable is a no-op on Spanner, but the SQL statement is supported for + // PostgreSQL-dialect databases for compatibility reasons. + row = tx.QueryRowContext(ctx, "show transaction_deferrable") + if err := row.Err(); err != nil { + t.Fatal(err) + } + var deferrable bool + if err := row.Scan(&deferrable); err != nil { + t.Fatal(err) + } + _ = tx.Commit() + + if g, w := deferrable, true; g != w { + t.Fatalf("deferrable mismatch\n Got: %v\nWant: %v", g, w) + } +} From 9c7c46276759a2ed1dd4740ea39c69d2acb828a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Sun, 5 Oct 2025 17:30:20 +0200 Subject: [PATCH 2/4] feat: support multiple transaction options in one statement --- parser/statements.go | 88 ++++++++++++++--------------- parser/statements_test.go | 113 +++++++++++++++++++++++++++----------- statements.go | 16 +++++- 3 files changed, 141 insertions(+), 76 deletions(-) diff --git a/parser/statements.go b/parser/statements.go index abfc765d..56307399 100644 --- a/parser/statements.go +++ b/parser/statements.go @@ -151,11 +151,19 @@ func (s *ParsedShowStatement) parse(parser *StatementParser, query string) error // activated. Examples include: // SET TRANSACTION READ ONLY // SET TRANSACTION ISOLATION LEVEL [SERIALIZABLE | REPEATABLE READ] +// +// One SET statement can set more than one property. type ParsedSetStatement struct { - query string - Identifier Identifier - Literal Literal - IsLocal bool + query string + // Identifiers contains the properties that are being set. The number of elements in this slice + // must be equal to the number of Literals. + Identifiers []Identifier + // Literals contains the values that should be set for the properties. + Literals []Literal + // IsLocal indicates whether this is a SET LOCAL statement or not. + IsLocal bool + // IsTransaction indicates whether this is a SET TRANSACTION statement or not. + // IsTransaction automatically also implies IsLocal. IsTransaction bool } @@ -208,8 +216,8 @@ func (s *ParsedSetStatement) parse(parser *StatementParser, query string) error return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) } s.query = query - s.Identifier = identifier - s.Literal = literalValue + s.Identifiers = []Identifier{identifier} + s.Literals = []Literal{literalValue} s.IsLocal = isLocal return nil } @@ -218,21 +226,33 @@ func (s *ParsedSetStatement) parseSetTransaction(sp *simpleParser, query string) if !sp.hasMoreTokens() { return status.Errorf(codes.InvalidArgument, "syntax error: missing TRANSACTION OPTION, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") } - // TODO: Support multiple transaction mode settings in one statement: - // SET TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE - if sp.peekKeyword("ISOLATION") { - return s.parseSetTransactionIsolationLevel(sp, query) - } - if sp.peekKeyword("READ") { - return s.parseSetTransactionMode(sp, query) - } - if sp.statementParser.Dialect == databasepb.DatabaseDialect_POSTGRESQL { - // https://www.postgresql.org/docs/current/sql-set-transaction.html - if sp.peekKeyword("DEFERRABLE") || sp.peekKeyword("NOT") { - return s.parseSetTransactionDeferrable(sp, query) + s.query = query + s.IsLocal = true + s.IsTransaction = true + + for { + if sp.peekKeyword("ISOLATION") { + if err := s.parseSetTransactionIsolationLevel(sp, query); err != nil { + return err + } + } else if sp.peekKeyword("READ") { + if err := s.parseSetTransactionMode(sp, query); err != nil { + return err + } + } else if sp.statementParser.Dialect == databasepb.DatabaseDialect_POSTGRESQL && (sp.peekKeyword("DEFERRABLE") || sp.peekKeyword("NOT")) { + // https://www.postgresql.org/docs/current/sql-set-transaction.html + if err := s.parseSetTransactionDeferrable(sp, query); err != nil { + return err + } + } else { + return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") + } + if !sp.hasMoreTokens() { + return nil } + // Eat and ignore any commas separating the various options. + sp.eatToken(',') } - return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") } func (s *ParsedSetStatement) parseSetTransactionIsolationLevel(sp *simpleParser, query string) error { @@ -247,15 +267,9 @@ func (s *ParsedSetStatement) parseSetTransactionIsolationLevel(sp *simpleParser, } else { return status.Errorf(codes.InvalidArgument, "syntax error: expected SERIALIZABLE OR REPETABLE READ") } - if sp.hasMoreTokens() { - return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) - } - s.query = query - s.Identifier = Identifier{Parts: []string{"isolation_level"}} - s.Literal = value - s.IsLocal = true - s.IsTransaction = true + s.Identifiers = append(s.Identifiers, Identifier{Parts: []string{"isolation_level"}}) + s.Literals = append(s.Literals, value) return nil } @@ -268,15 +282,9 @@ func (s *ParsedSetStatement) parseSetTransactionMode(sp *simpleParser, query str } else { return status.Errorf(codes.InvalidArgument, "syntax error: expected READ ONLY or READ WRITE") } - if sp.hasMoreTokens() { - return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) - } - s.query = query - s.Identifier = Identifier{Parts: []string{"transaction_read_only"}} - s.Literal = Literal{Value: fmt.Sprintf("%v", readOnly)} - s.IsLocal = true - s.IsTransaction = true + s.Identifiers = append(s.Identifiers, Identifier{Parts: []string{"transaction_read_only"}}) + s.Literals = append(s.Literals, Literal{Value: fmt.Sprintf("%v", readOnly)}) return nil } @@ -289,15 +297,9 @@ func (s *ParsedSetStatement) parseSetTransactionDeferrable(sp *simpleParser, que } else { return status.Errorf(codes.InvalidArgument, "syntax error: expected [NOT] DEFERRABLE") } - if sp.hasMoreTokens() { - return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) - } - s.query = query - s.Identifier = Identifier{Parts: []string{"transaction_deferrable"}} - s.Literal = Literal{Value: fmt.Sprintf("%v", deferrable)} - s.IsLocal = true - s.IsTransaction = true + s.Identifiers = append(s.Identifiers, Identifier{Parts: []string{"transaction_deferrable"}}) + s.Literals = append(s.Literals, Literal{Value: fmt.Sprintf("%v", deferrable)}) return nil } diff --git a/parser/statements_test.go b/parser/statements_test.go index b0718d57..91fada7e 100644 --- a/parser/statements_test.go +++ b/parser/statements_test.go @@ -136,58 +136,58 @@ func TestParseSetStatement(t *testing.T) { { input: "set my_property = 'foo'", want: ParsedSetStatement{ - query: "set my_property = 'foo'", - Identifier: Identifier{Parts: []string{"my_property"}}, - Literal: Literal{Value: "foo"}, + query: "set my_property = 'foo'", + Identifiers: []Identifier{{Parts: []string{"my_property"}}}, + Literals: []Literal{{Value: "foo"}}, }, }, { input: "set local my_property = 'foo'", want: ParsedSetStatement{ - query: "set local my_property = 'foo'", - Identifier: Identifier{Parts: []string{"my_property"}}, - Literal: Literal{Value: "foo"}, - IsLocal: true, + query: "set local my_property = 'foo'", + Identifiers: []Identifier{{Parts: []string{"my_property"}}}, + Literals: []Literal{{Value: "foo"}}, + IsLocal: true, }, }, { input: "set my_property = 1", want: ParsedSetStatement{ - query: "set my_property = 1", - Identifier: Identifier{Parts: []string{"my_property"}}, - Literal: Literal{Value: "1"}, + query: "set my_property = 1", + Identifiers: []Identifier{{Parts: []string{"my_property"}}}, + Literals: []Literal{{Value: "1"}}, }, }, { input: "set my_property = true", want: ParsedSetStatement{ - query: "set my_property = true", - Identifier: Identifier{Parts: []string{"my_property"}}, - Literal: Literal{Value: "true"}, + query: "set my_property = true", + Identifiers: []Identifier{{Parts: []string{"my_property"}}}, + Literals: []Literal{{Value: "true"}}, }, }, { input: "set \n -- comment \n my_property /* yet more comments */ = \ntrue/*comment*/ ", want: ParsedSetStatement{ - query: "set \n -- comment \n my_property /* yet more comments */ = \ntrue/*comment*/ ", - Identifier: Identifier{Parts: []string{"my_property"}}, - Literal: Literal{Value: "true"}, + query: "set \n -- comment \n my_property /* yet more comments */ = \ntrue/*comment*/ ", + Identifiers: []Identifier{{Parts: []string{"my_property"}}}, + Literals: []Literal{{Value: "true"}}, }, }, { input: "set \n -- comment \n a.b /* yet more comments */ =\n/*comment*/'value'/*comment*/ ", want: ParsedSetStatement{ - query: "set \n -- comment \n a.b /* yet more comments */ =\n/*comment*/'value'/*comment*/ ", - Identifier: Identifier{Parts: []string{"a", "b"}}, - Literal: Literal{Value: "value"}, + query: "set \n -- comment \n a.b /* yet more comments */ =\n/*comment*/'value'/*comment*/ ", + Identifiers: []Identifier{{Parts: []string{"a", "b"}}}, + Literals: []Literal{{Value: "value"}}, }, }, { input: "set transaction isolation level serializable", want: ParsedSetStatement{ query: "set transaction isolation level serializable", - Identifier: Identifier{Parts: []string{"isolation_level"}}, - Literal: Literal{Value: "serializable"}, + Identifiers: []Identifier{{Parts: []string{"isolation_level"}}}, + Literals: []Literal{{Value: "serializable"}}, IsLocal: true, IsTransaction: true, }, @@ -212,8 +212,8 @@ func TestParseSetStatement(t *testing.T) { input: "set transaction isolation level repeatable read", want: ParsedSetStatement{ query: "set transaction isolation level repeatable read", - Identifier: Identifier{Parts: []string{"isolation_level"}}, - Literal: Literal{Value: "repeatable_read"}, + Identifiers: []Identifier{{Parts: []string{"isolation_level"}}}, + Literals: []Literal{{Value: "repeatable_read"}}, IsLocal: true, IsTransaction: true, }, @@ -238,8 +238,59 @@ func TestParseSetStatement(t *testing.T) { input: "set transaction read write", want: ParsedSetStatement{ query: "set transaction read write", - Identifier: Identifier{Parts: []string{"transaction_read_only"}}, - Literal: Literal{Value: "false"}, + Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}}, + Literals: []Literal{{Value: "false"}}, + IsLocal: true, + IsTransaction: true, + }, + }, + { + input: "set transaction read write isolation level serializable", + want: ParsedSetStatement{ + query: "set transaction read write isolation level serializable", + Identifiers: []Identifier{ + {Parts: []string{"transaction_read_only"}}, + {Parts: []string{"isolation_level"}}, + }, + Literals: []Literal{ + {Value: "false"}, + {Value: "serializable"}, + }, + IsLocal: true, + IsTransaction: true, + }, + }, + { + input: "set transaction read only, isolation level repeatable read", + want: ParsedSetStatement{ + query: "set transaction read only, isolation level repeatable read", + Identifiers: []Identifier{ + {Parts: []string{"transaction_read_only"}}, + {Parts: []string{"isolation_level"}}, + }, + Literals: []Literal{ + {Value: "true"}, + {Value: "repeatable_read"}, + }, + IsLocal: true, + IsTransaction: true, + }, + }, + { + onlyPg: true, + input: "set transaction read only, isolation level repeatable read not deferrable", + want: ParsedSetStatement{ + query: "set transaction read only, isolation level repeatable read not deferrable", + Identifiers: []Identifier{ + {Parts: []string{"transaction_read_only"}}, + {Parts: []string{"isolation_level"}}, + {Parts: []string{"transaction_deferrable"}}, + }, + Literals: []Literal{ + {Value: "true"}, + {Value: "repeatable_read"}, + {Value: "false"}, + }, IsLocal: true, IsTransaction: true, }, @@ -248,8 +299,8 @@ func TestParseSetStatement(t *testing.T) { input: "set transaction read only", want: ParsedSetStatement{ query: "set transaction read only", - Identifier: Identifier{Parts: []string{"transaction_read_only"}}, - Literal: Literal{Value: "true"}, + Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}}, + Literals: []Literal{{Value: "true"}}, IsLocal: true, IsTransaction: true, }, @@ -275,8 +326,8 @@ func TestParseSetStatement(t *testing.T) { input: "set transaction deferrable", want: ParsedSetStatement{ query: "set transaction deferrable", - Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, - Literal: Literal{Value: "true"}, + Identifiers: []Identifier{{Parts: []string{"transaction_deferrable"}}}, + Literals: []Literal{{Value: "true"}}, IsLocal: true, IsTransaction: true, }, @@ -286,8 +337,8 @@ func TestParseSetStatement(t *testing.T) { input: "set transaction not deferrable", want: ParsedSetStatement{ query: "set transaction not deferrable", - Identifier: Identifier{Parts: []string{"transaction_deferrable"}}, - Literal: Literal{Value: "false"}, + Identifiers: []Identifier{{Parts: []string{"transaction_deferrable"}}}, + Literals: []Literal{{Value: "false"}}, IsLocal: true, IsTransaction: true, }, diff --git a/statements.go b/statements.go index f8e06d02..2c7c7a7a 100644 --- a/statements.go +++ b/statements.go @@ -112,19 +112,31 @@ type executableSetStatement struct { } func (s *executableSetStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) { - if err := c.setConnectionVariable(s.stmt.Identifier, s.stmt.Literal.Value, s.stmt.IsLocal, s.stmt.IsTransaction); err != nil { + if err := s.execute(c); err != nil { return nil, err } return driver.ResultNoRows, nil } func (s *executableSetStatement) queryContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Rows, error) { - if err := c.setConnectionVariable(s.stmt.Identifier, s.stmt.Literal.Value, s.stmt.IsLocal, s.stmt.IsTransaction); err != nil { + if err := s.execute(c); err != nil { return nil, err } return createEmptyRows(opts), nil } +func (s *executableSetStatement) execute(c *conn) error { + if len(s.stmt.Identifiers) != len(s.stmt.Literals) { + return status.Errorf(codes.InvalidArgument, "statement contains %d identifiers, but %d values given", len(s.stmt.Identifiers), len(s.stmt.Literals)) + } + for index := range s.stmt.Identifiers { + if err := c.setConnectionVariable(s.stmt.Identifiers[index], s.stmt.Literals[index].Value, s.stmt.IsLocal, s.stmt.IsTransaction); err != nil { + return err + } + } + return nil +} + // RESET [my_extension.]my_property type executableResetStatement struct { stmt *parser.ParsedResetStatement From 1f4fbf3acd067669f6a3621d88763ef97966d9ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Mon, 6 Oct 2025 12:30:14 +0200 Subject: [PATCH 3/4] feat: support transaction options in BEGIN statements Adds support for including transaction options in BEGIN statements, like: ```sql BEGIN READ ONLY; BEGIN READ WRITE; BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ; BEGIN READ WRITE, ISOLATION LEVEL SERIALIZABLE; ``` --- parser/statements.go | 79 ++++++++++++++++--------- parser/statements_test.go | 86 ++++++++++++++++++++++++++- statements.go | 9 +++ transaction_test.go | 119 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 264 insertions(+), 29 deletions(-) diff --git a/parser/statements.go b/parser/statements.go index 56307399..8c05548b 100644 --- a/parser/statements.go +++ b/parser/statements.go @@ -230,34 +230,49 @@ func (s *ParsedSetStatement) parseSetTransaction(sp *simpleParser, query string) s.IsLocal = true s.IsTransaction = true + var err error + s.Identifiers, s.Literals, err = parseTransactionOptions(sp) + if err != nil { + return err + } + return nil +} + +func parseTransactionOptions(sp *simpleParser) ([]Identifier, []Literal, error) { + identifiers := make([]Identifier, 0, 2) + literals := make([]Literal, 0, 2) + var err error for { if sp.peekKeyword("ISOLATION") { - if err := s.parseSetTransactionIsolationLevel(sp, query); err != nil { - return err + identifiers, literals, err = parseTransactionIsolationLevel(sp, identifiers, literals) + if err != nil { + return nil, nil, err } } else if sp.peekKeyword("READ") { - if err := s.parseSetTransactionMode(sp, query); err != nil { - return err + identifiers, literals, err = parseTransactionMode(sp, identifiers, literals) + if err != nil { + return nil, nil, err } } else if sp.statementParser.Dialect == databasepb.DatabaseDialect_POSTGRESQL && (sp.peekKeyword("DEFERRABLE") || sp.peekKeyword("NOT")) { // https://www.postgresql.org/docs/current/sql-set-transaction.html - if err := s.parseSetTransactionDeferrable(sp, query); err != nil { - return err + identifiers, literals, err = parseTransactionDeferrable(sp, identifiers, literals) + if err != nil { + return nil, nil, err } } else { - return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") + return nil, nil, status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") } if !sp.hasMoreTokens() { - return nil + return identifiers, literals, nil } // Eat and ignore any commas separating the various options. sp.eatToken(',') } } -func (s *ParsedSetStatement) parseSetTransactionIsolationLevel(sp *simpleParser, query string) error { +func parseTransactionIsolationLevel(sp *simpleParser, identifiers []Identifier, literals []Literal) ([]Identifier, []Literal, error) { if !sp.eatKeywords([]string{"ISOLATION", "LEVEL"}) { - return status.Errorf(codes.InvalidArgument, "syntax error: expected ISOLATION LEVEL") + return nil, nil, status.Errorf(codes.InvalidArgument, "syntax error: expected ISOLATION LEVEL") } var value Literal if sp.eatKeyword("SERIALIZABLE") { @@ -265,42 +280,42 @@ func (s *ParsedSetStatement) parseSetTransactionIsolationLevel(sp *simpleParser, } else if sp.eatKeywords([]string{"REPEATABLE", "READ"}) { value = Literal{Value: "repeatable_read"} } else { - return status.Errorf(codes.InvalidArgument, "syntax error: expected SERIALIZABLE OR REPETABLE READ") + return nil, nil, status.Errorf(codes.InvalidArgument, "syntax error: expected SERIALIZABLE OR REPETABLE READ") } - s.Identifiers = append(s.Identifiers, Identifier{Parts: []string{"isolation_level"}}) - s.Literals = append(s.Literals, value) - return nil + identifiers = append(identifiers, Identifier{Parts: []string{"isolation_level"}}) + literals = append(literals, value) + return identifiers, literals, nil } -func (s *ParsedSetStatement) parseSetTransactionMode(sp *simpleParser, query string) error { +func parseTransactionMode(sp *simpleParser, identifiers []Identifier, literals []Literal) ([]Identifier, []Literal, error) { readOnly := false if sp.eatKeywords([]string{"READ", "ONLY"}) { readOnly = true } else if sp.eatKeywords([]string{"READ", "WRITE"}) { readOnly = false } else { - return status.Errorf(codes.InvalidArgument, "syntax error: expected READ ONLY or READ WRITE") + return nil, nil, status.Errorf(codes.InvalidArgument, "syntax error: expected READ ONLY or READ WRITE") } - s.Identifiers = append(s.Identifiers, Identifier{Parts: []string{"transaction_read_only"}}) - s.Literals = append(s.Literals, Literal{Value: fmt.Sprintf("%v", readOnly)}) - return nil + identifiers = append(identifiers, Identifier{Parts: []string{"transaction_read_only"}}) + literals = append(literals, Literal{Value: fmt.Sprintf("%v", readOnly)}) + return identifiers, literals, nil } -func (s *ParsedSetStatement) parseSetTransactionDeferrable(sp *simpleParser, query string) error { +func parseTransactionDeferrable(sp *simpleParser, identifiers []Identifier, literals []Literal) ([]Identifier, []Literal, error) { deferrable := false if sp.eatKeywords([]string{"NOT", "DEFERRABLE"}) { deferrable = false } else if sp.eatKeyword("DEFERRABLE") { deferrable = true } else { - return status.Errorf(codes.InvalidArgument, "syntax error: expected [NOT] DEFERRABLE") + return nil, nil, status.Errorf(codes.InvalidArgument, "syntax error: expected [NOT] DEFERRABLE") } - s.Identifiers = append(s.Identifiers, Identifier{Parts: []string{"transaction_deferrable"}}) - s.Literals = append(s.Literals, Literal{Value: fmt.Sprintf("%v", deferrable)}) - return nil + identifiers = append(identifiers, Identifier{Parts: []string{"transaction_deferrable"}}) + literals = append(literals, Literal{Value: fmt.Sprintf("%v", deferrable)}) + return identifiers, literals, nil } // ParsedResetStatement is a statement of the form @@ -496,6 +511,12 @@ func (s *ParsedAbortBatchStatement) parse(parser *StatementParser, query string) type ParsedBeginStatement struct { query string + // Identifiers contains the transaction properties that were included in the BEGIN statement. E.g. the statement + // BEGIN TRANSACTION READ ONLY contains the transaction property 'transaction_read_only'. + Identifiers []Identifier + // Literals contains the transaction property values that were included in the BEGIN statement. E.g. the statement + // BEGIN TRANSACTION READ ONLY contains the value 'true' for the property 'transaction_read_only'. + Literals []Literal } func (s *ParsedBeginStatement) Name() string { @@ -508,9 +529,8 @@ func (s *ParsedBeginStatement) Query() string { func (s *ParsedBeginStatement) parse(parser *StatementParser, query string) error { // Parse a statement of the form - // GoogleSQL: BEGIN [TRANSACTION] + // GoogleSQL: BEGIN [TRANSACTION] [READ WRITE | READ ONLY | ISOLATION LEVEL {SERIALIZABLE | READ COMMITTED}] // PostgreSQL: {START | BEGIN} [{TRANSACTION | WORK}] (https://www.postgresql.org/docs/current/sql-begin.html) - // TODO: Support transaction modes in the BEGIN / START statement. sp := &simpleParser{sql: []byte(query), statementParser: parser} if sp.statementParser.Dialect == databasepb.DatabaseDialect_POSTGRESQL { if !sp.eatKeyword("START") && !sp.eatKeyword("BEGIN") { @@ -531,8 +551,13 @@ func (s *ParsedBeginStatement) parse(parser *StatementParser, query string) erro } if sp.hasMoreTokens() { - return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) + var err error + s.Identifiers, s.Literals, err = parseTransactionOptions(sp) + if err != nil { + return err + } } + s.query = query return nil } diff --git a/parser/statements_test.go b/parser/statements_test.go index 91fada7e..45361c74 100644 --- a/parser/statements_test.go +++ b/parser/statements_test.go @@ -431,6 +431,38 @@ func TestParseBeginStatementGoogleSQL(t *testing.T) { input: "begin transaction foo", wantErr: true, }, + { + input: "begin read only", + want: ParsedBeginStatement{ + query: "begin read only", + Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}}, + Literals: []Literal{{Value: "true"}}, + }, + }, + { + input: "begin read write", + want: ParsedBeginStatement{ + query: "begin read write", + Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}}, + Literals: []Literal{{Value: "false"}}, + }, + }, + { + input: "begin transaction isolation level serializable", + want: ParsedBeginStatement{ + query: "begin transaction isolation level serializable", + Identifiers: []Identifier{{Parts: []string{"isolation_level"}}}, + Literals: []Literal{{Value: "serializable"}}, + }, + }, + { + input: "begin transaction isolation level repeatable read, read write", + want: ParsedBeginStatement{ + query: "begin transaction isolation level repeatable read, read write", + Identifiers: []Identifier{{Parts: []string{"isolation_level"}}, {Parts: []string{"transaction_read_only"}}}, + Literals: []Literal{{Value: "repeatable_read"}, {Value: "false"}}, + }, + }, } parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000) if err != nil { @@ -454,7 +486,7 @@ func TestParseBeginStatementGoogleSQL(t *testing.T) { t.Fatalf("parseStatement(%q) should have returned a *parsedBeginStatement", test.input) } if !reflect.DeepEqual(*showStmt, test.want) { - t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, test.want) + t.Errorf("parseStatement(%q) mismatch\n Got: %v\nWant: %v", test.input, *showStmt, test.want) } } }) @@ -506,6 +538,56 @@ func TestParseBeginStatementPostgreSQL(t *testing.T) { query: "start work", }, }, + { + input: "start work read only", + want: ParsedBeginStatement{ + query: "start work read only", + Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}}, + Literals: []Literal{{Value: "true"}}, + }, + }, + { + input: "begin read write", + want: ParsedBeginStatement{ + query: "begin read write", + Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}}, + Literals: []Literal{{Value: "false"}}, + }, + }, + { + input: "begin read write, isolation level repeatable read", + want: ParsedBeginStatement{ + query: "begin read write, isolation level repeatable read", + Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}, {Parts: []string{"isolation_level"}}}, + Literals: []Literal{{Value: "false"}, {Value: "repeatable_read"}}, + }, + }, + { + // Note that it is possible to set multiple conflicting transaction options in one statement. + // This statement for example sets the transaction to both read/write and read-only. + // The last option will take precedence, as these options are essentially the same as executing the + // following statements sequentially after the BEGIN TRANSACTION statement: + // SET TRANSACTION READ WRITE + // SET TRANSACTION ISOLATION LEVEL REPEATABLE READ + // SET TRANSACTION READ ONLY + // SET TRANSACTION DEFERRABLE + input: "begin transaction \nread write,\nisolation level repeatable read\nread only\ndeferrable", + want: ParsedBeginStatement{ + query: "begin transaction \nread write,\nisolation level repeatable read\nread only\ndeferrable", + Identifiers: []Identifier{ + {Parts: []string{"transaction_read_only"}}, + {Parts: []string{"isolation_level"}}, + {Parts: []string{"transaction_read_only"}}, + {Parts: []string{"transaction_deferrable"}}, + }, + Literals: []Literal{ + {Value: "false"}, + {Value: "repeatable_read"}, + {Value: "true"}, + {Value: "true"}, + }, + }, + }, { input: "start foo", wantErr: true, @@ -541,7 +623,7 @@ func TestParseBeginStatementPostgreSQL(t *testing.T) { t.Fatalf("parseStatement(%q) should have returned a *parsedBeginStatement", test.input) } if !reflect.DeepEqual(*showStmt, test.want) { - t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, test.want) + t.Errorf("parseStatement(%q) mismatch\n Got: %v\nWant: %v", test.input, *showStmt, test.want) } } }) diff --git a/statements.go b/statements.go index 2c7c7a7a..147146b5 100644 --- a/statements.go +++ b/statements.go @@ -279,10 +279,19 @@ type executableBeginStatement struct { } func (s *executableBeginStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) { + if len(s.stmt.Identifiers) != len(s.stmt.Literals) { + return nil, status.Errorf(codes.InvalidArgument, "statement contains %d identifiers, but %d values given", len(s.stmt.Identifiers), len(s.stmt.Literals)) + } _, err := c.BeginTx(ctx, driver.TxOptions{}) if err != nil { return nil, err } + for index := range s.stmt.Identifiers { + if err := c.setConnectionVariable(s.stmt.Identifiers[index], s.stmt.Literals[index].Value /*IsLocal=*/, true /*IsTransaction=*/, true); err != nil { + return nil, err + } + } + return driver.ResultNoRows, nil } diff --git a/transaction_test.go b/transaction_test.go index d7c49ad2..0842279b 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -105,3 +105,122 @@ func TestSetTransactionDeferrable(t *testing.T) { t.Fatalf("deferrable mismatch\n Got: %v\nWant: %v", g, w) } } + +func TestBeginTransactionIsolationLevel(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer silentClose(conn) + + if _, err := conn.ExecContext(ctx, "begin transaction isolation level repeatable read"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "commit"); err != nil { + t.Fatal(err) + } + + requests := server.TestSpanner.DrainRequestsFromServer() + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil { + t.Fatal("missing begin transaction on ExecuteSqlRequest") + } + if g, w := request.GetTransaction().GetBegin().GetIsolationLevel(), spannerpb.TransactionOptions_REPEATABLE_READ; g != w { + t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginTransactionReadOnly(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer silentClose(conn) + + if _, err := conn.ExecContext(ctx, "begin transaction read write"); err != nil { + t.Fatal(err) + } + row := conn.QueryRowContext(ctx, testutil.SelectFooFromBar, ExecOptions{DirectExecuteQuery: true}) + var c int64 + // If we don't call row.Scan(..), then the underlying Rows object is not closed. That again means that the + // connection cannot be released. + if err := row.Scan(&c); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "commit"); err != nil { + t.Fatal(err) + } + + requests := server.TestSpanner.DrainRequestsFromServer() + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil { + t.Fatal("missing begin transaction on ExecuteSqlRequest") + } + // TODO: Enable once transaction_read_only is picked up by the driver. + //readOnly := request.GetTransaction().GetBegin().GetReadOnly() + //if readOnly == nil { + // t.Fatal("missing readOnly on ExecuteSqlRequest") + //} +} + +func TestBeginTransactionDeferrable(t *testing.T) { + t.Parallel() + + // BEGIN TRANSACTION [NOT] DEFERRABLE is only supported for PostgreSQL-dialect databases. + db, _, teardown := setupTestDBConnectionWithParamsAndDialect(t, "", databasepb.DatabaseDialect_POSTGRESQL) + defer teardown() + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer silentClose(conn) + + if _, err := conn.ExecContext(ctx, "begin transaction deferrable"); err != nil { + t.Fatal(err) + } + row := conn.QueryRowContext(ctx, testutil.SelectFooFromBar, ExecOptions{DirectExecuteQuery: true}) + var c int64 + if err := row.Scan(&c); err != nil { + t.Fatal(err) + } + + // transaction_deferrable is a no-op on Spanner, but the SQL statement is supported for + // PostgreSQL-dialect databases for compatibility reasons. + row = conn.QueryRowContext(ctx, "show transaction_deferrable") + var deferrable bool + if err := row.Scan(&deferrable); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "commit"); err != nil { + t.Fatal(err) + } + + if g, w := deferrable, true; g != w { + t.Fatalf("deferrable mismatch\n Got: %v\nWant: %v", g, w) + } +} From 70efce49c7a55618862c14c79e26700739ba9e5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Tue, 7 Oct 2025 19:29:12 +0200 Subject: [PATCH 4/4] chore: re-trigger checks