diff --git a/parser/statements.go b/parser/statements.go index 56307399..d45b01e3 100644 --- a/parser/statements.go +++ b/parser/statements.go @@ -230,34 +230,49 @@ func (s *ParsedSetStatement) parseSetTransaction(sp *simpleParser, query string) s.IsLocal = true s.IsTransaction = true + var err error + s.Identifiers, s.Literals, err = parseTransactionOptions(sp) + if err != nil { + return err + } + return nil +} + +func parseTransactionOptions(sp *simpleParser) ([]Identifier, []Literal, error) { + identifiers := make([]Identifier, 0, 2) + literals := make([]Literal, 0, 2) + var err error for { if sp.peekKeyword("ISOLATION") { - if err := s.parseSetTransactionIsolationLevel(sp, query); err != nil { - return err + identifiers, literals, err = parseTransactionIsolationLevel(sp, identifiers, literals) + if err != nil { + return nil, nil, err } } else if sp.peekKeyword("READ") { - if err := s.parseSetTransactionMode(sp, query); err != nil { - return err + identifiers, literals, err = parseTransactionMode(sp, identifiers, literals) + if err != nil { + return nil, nil, err } } else if sp.statementParser.Dialect == databasepb.DatabaseDialect_POSTGRESQL && (sp.peekKeyword("DEFERRABLE") || sp.peekKeyword("NOT")) { // https://www.postgresql.org/docs/current/sql-set-transaction.html - if err := s.parseSetTransactionDeferrable(sp, query); err != nil { - return err + identifiers, literals, err = parseTransactionDeferrable(sp, identifiers, literals) + if err != nil { + return nil, nil, err } } else { - return status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") + return nil, nil, status.Error(codes.InvalidArgument, "invalid TRANSACTION option, expected one of ISOLATION LEVEL, READ WRITE, or READ ONLY") } if !sp.hasMoreTokens() { - return nil + return identifiers, literals, nil } // Eat and ignore any commas separating the various options. sp.eatToken(',') } } -func (s *ParsedSetStatement) parseSetTransactionIsolationLevel(sp *simpleParser, query string) error { +func parseTransactionIsolationLevel(sp *simpleParser, identifiers []Identifier, literals []Literal) ([]Identifier, []Literal, error) { if !sp.eatKeywords([]string{"ISOLATION", "LEVEL"}) { - return status.Errorf(codes.InvalidArgument, "syntax error: expected ISOLATION LEVEL") + return nil, nil, status.Errorf(codes.InvalidArgument, "syntax error: expected ISOLATION LEVEL") } var value Literal if sp.eatKeyword("SERIALIZABLE") { @@ -265,42 +280,42 @@ func (s *ParsedSetStatement) parseSetTransactionIsolationLevel(sp *simpleParser, } else if sp.eatKeywords([]string{"REPEATABLE", "READ"}) { value = Literal{Value: "repeatable_read"} } else { - return status.Errorf(codes.InvalidArgument, "syntax error: expected SERIALIZABLE OR REPETABLE READ") + return nil, nil, status.Errorf(codes.InvalidArgument, "syntax error: expected SERIALIZABLE OR REPETABLE READ") } - s.Identifiers = append(s.Identifiers, Identifier{Parts: []string{"isolation_level"}}) - s.Literals = append(s.Literals, value) - return nil + identifiers = append(identifiers, Identifier{Parts: []string{"isolation_level"}}) + literals = append(literals, value) + return identifiers, literals, nil } -func (s *ParsedSetStatement) parseSetTransactionMode(sp *simpleParser, query string) error { +func parseTransactionMode(sp *simpleParser, identifiers []Identifier, literals []Literal) ([]Identifier, []Literal, error) { readOnly := false if sp.eatKeywords([]string{"READ", "ONLY"}) { readOnly = true } else if sp.eatKeywords([]string{"READ", "WRITE"}) { readOnly = false } else { - return status.Errorf(codes.InvalidArgument, "syntax error: expected READ ONLY or READ WRITE") + return nil, nil, status.Errorf(codes.InvalidArgument, "syntax error: expected READ ONLY or READ WRITE") } - s.Identifiers = append(s.Identifiers, Identifier{Parts: []string{"transaction_read_only"}}) - s.Literals = append(s.Literals, Literal{Value: fmt.Sprintf("%v", readOnly)}) - return nil + identifiers = append(identifiers, Identifier{Parts: []string{"transaction_read_only"}}) + literals = append(literals, Literal{Value: fmt.Sprintf("%v", readOnly)}) + return identifiers, literals, nil } -func (s *ParsedSetStatement) parseSetTransactionDeferrable(sp *simpleParser, query string) error { +func parseTransactionDeferrable(sp *simpleParser, identifiers []Identifier, literals []Literal) ([]Identifier, []Literal, error) { deferrable := false if sp.eatKeywords([]string{"NOT", "DEFERRABLE"}) { deferrable = false } else if sp.eatKeyword("DEFERRABLE") { deferrable = true } else { - return status.Errorf(codes.InvalidArgument, "syntax error: expected [NOT] DEFERRABLE") + return nil, nil, status.Errorf(codes.InvalidArgument, "syntax error: expected [NOT] DEFERRABLE") } - s.Identifiers = append(s.Identifiers, Identifier{Parts: []string{"transaction_deferrable"}}) - s.Literals = append(s.Literals, Literal{Value: fmt.Sprintf("%v", deferrable)}) - return nil + identifiers = append(identifiers, Identifier{Parts: []string{"transaction_deferrable"}}) + literals = append(literals, Literal{Value: fmt.Sprintf("%v", deferrable)}) + return identifiers, literals, nil } // ParsedResetStatement is a statement of the form @@ -496,6 +511,12 @@ func (s *ParsedAbortBatchStatement) parse(parser *StatementParser, query string) type ParsedBeginStatement struct { query string + // Identifiers contains the transaction properties that were included in the BEGIN statement. E.g. the statement + // BEGIN TRANSACTION READ ONLY contains the transaction property 'transaction_read_only'. + Identifiers []Identifier + // Literals contains the transaction property values that were included in the BEGIN statement. E.g. the statement + // BEGIN TRANSACTION READ ONLY contains the value 'true' for the property 'transaction_read_only'. + Literals []Literal } func (s *ParsedBeginStatement) Name() string { @@ -508,7 +529,7 @@ func (s *ParsedBeginStatement) Query() string { func (s *ParsedBeginStatement) parse(parser *StatementParser, query string) error { // Parse a statement of the form - // GoogleSQL: BEGIN [TRANSACTION] + // GoogleSQL: BEGIN [TRANSACTION] [READ WRITE | READ ONLY | ISOLATION LEVEL {SERIALIZABLE | READ COMMITTED}] // PostgreSQL: {START | BEGIN} [{TRANSACTION | WORK}] (https://www.postgresql.org/docs/current/sql-begin.html) // TODO: Support transaction modes in the BEGIN / START statement. sp := &simpleParser{sql: []byte(query), statementParser: parser} @@ -531,8 +552,13 @@ func (s *ParsedBeginStatement) parse(parser *StatementParser, query string) erro } if sp.hasMoreTokens() { - return status.Errorf(codes.InvalidArgument, "unexpected tokens at position %d in %q", sp.pos, sp.sql) + var err error + s.Identifiers, s.Literals, err = parseTransactionOptions(sp) + if err != nil { + return err + } } + s.query = query return nil } diff --git a/parser/statements_test.go b/parser/statements_test.go index 91fada7e..45361c74 100644 --- a/parser/statements_test.go +++ b/parser/statements_test.go @@ -431,6 +431,38 @@ func TestParseBeginStatementGoogleSQL(t *testing.T) { input: "begin transaction foo", wantErr: true, }, + { + input: "begin read only", + want: ParsedBeginStatement{ + query: "begin read only", + Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}}, + Literals: []Literal{{Value: "true"}}, + }, + }, + { + input: "begin read write", + want: ParsedBeginStatement{ + query: "begin read write", + Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}}, + Literals: []Literal{{Value: "false"}}, + }, + }, + { + input: "begin transaction isolation level serializable", + want: ParsedBeginStatement{ + query: "begin transaction isolation level serializable", + Identifiers: []Identifier{{Parts: []string{"isolation_level"}}}, + Literals: []Literal{{Value: "serializable"}}, + }, + }, + { + input: "begin transaction isolation level repeatable read, read write", + want: ParsedBeginStatement{ + query: "begin transaction isolation level repeatable read, read write", + Identifiers: []Identifier{{Parts: []string{"isolation_level"}}, {Parts: []string{"transaction_read_only"}}}, + Literals: []Literal{{Value: "repeatable_read"}, {Value: "false"}}, + }, + }, } parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000) if err != nil { @@ -454,7 +486,7 @@ func TestParseBeginStatementGoogleSQL(t *testing.T) { t.Fatalf("parseStatement(%q) should have returned a *parsedBeginStatement", test.input) } if !reflect.DeepEqual(*showStmt, test.want) { - t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, test.want) + t.Errorf("parseStatement(%q) mismatch\n Got: %v\nWant: %v", test.input, *showStmt, test.want) } } }) @@ -506,6 +538,56 @@ func TestParseBeginStatementPostgreSQL(t *testing.T) { query: "start work", }, }, + { + input: "start work read only", + want: ParsedBeginStatement{ + query: "start work read only", + Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}}, + Literals: []Literal{{Value: "true"}}, + }, + }, + { + input: "begin read write", + want: ParsedBeginStatement{ + query: "begin read write", + Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}}, + Literals: []Literal{{Value: "false"}}, + }, + }, + { + input: "begin read write, isolation level repeatable read", + want: ParsedBeginStatement{ + query: "begin read write, isolation level repeatable read", + Identifiers: []Identifier{{Parts: []string{"transaction_read_only"}}, {Parts: []string{"isolation_level"}}}, + Literals: []Literal{{Value: "false"}, {Value: "repeatable_read"}}, + }, + }, + { + // Note that it is possible to set multiple conflicting transaction options in one statement. + // This statement for example sets the transaction to both read/write and read-only. + // The last option will take precedence, as these options are essentially the same as executing the + // following statements sequentially after the BEGIN TRANSACTION statement: + // SET TRANSACTION READ WRITE + // SET TRANSACTION ISOLATION LEVEL REPEATABLE READ + // SET TRANSACTION READ ONLY + // SET TRANSACTION DEFERRABLE + input: "begin transaction \nread write,\nisolation level repeatable read\nread only\ndeferrable", + want: ParsedBeginStatement{ + query: "begin transaction \nread write,\nisolation level repeatable read\nread only\ndeferrable", + Identifiers: []Identifier{ + {Parts: []string{"transaction_read_only"}}, + {Parts: []string{"isolation_level"}}, + {Parts: []string{"transaction_read_only"}}, + {Parts: []string{"transaction_deferrable"}}, + }, + Literals: []Literal{ + {Value: "false"}, + {Value: "repeatable_read"}, + {Value: "true"}, + {Value: "true"}, + }, + }, + }, { input: "start foo", wantErr: true, @@ -541,7 +623,7 @@ func TestParseBeginStatementPostgreSQL(t *testing.T) { t.Fatalf("parseStatement(%q) should have returned a *parsedBeginStatement", test.input) } if !reflect.DeepEqual(*showStmt, test.want) { - t.Errorf("parseStatement(%q) = %v, want %v", test.input, *showStmt, test.want) + t.Errorf("parseStatement(%q) mismatch\n Got: %v\nWant: %v", test.input, *showStmt, test.want) } } }) diff --git a/statements.go b/statements.go index 2c7c7a7a..147146b5 100644 --- a/statements.go +++ b/statements.go @@ -279,10 +279,19 @@ type executableBeginStatement struct { } func (s *executableBeginStatement) execContext(ctx context.Context, c *conn, opts *ExecOptions) (driver.Result, error) { + if len(s.stmt.Identifiers) != len(s.stmt.Literals) { + return nil, status.Errorf(codes.InvalidArgument, "statement contains %d identifiers, but %d values given", len(s.stmt.Identifiers), len(s.stmt.Literals)) + } _, err := c.BeginTx(ctx, driver.TxOptions{}) if err != nil { return nil, err } + for index := range s.stmt.Identifiers { + if err := c.setConnectionVariable(s.stmt.Identifiers[index], s.stmt.Literals[index].Value /*IsLocal=*/, true /*IsTransaction=*/, true); err != nil { + return nil, err + } + } + return driver.ResultNoRows, nil } diff --git a/transaction_test.go b/transaction_test.go index d7c49ad2..0842279b 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -105,3 +105,122 @@ func TestSetTransactionDeferrable(t *testing.T) { t.Fatalf("deferrable mismatch\n Got: %v\nWant: %v", g, w) } } + +func TestBeginTransactionIsolationLevel(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer silentClose(conn) + + if _, err := conn.ExecContext(ctx, "begin transaction isolation level repeatable read"); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "commit"); err != nil { + t.Fatal(err) + } + + requests := server.TestSpanner.DrainRequestsFromServer() + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil { + t.Fatal("missing begin transaction on ExecuteSqlRequest") + } + if g, w := request.GetTransaction().GetBegin().GetIsolationLevel(), spannerpb.TransactionOptions_REPEATABLE_READ; g != w { + t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginTransactionReadOnly(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer silentClose(conn) + + if _, err := conn.ExecContext(ctx, "begin transaction read write"); err != nil { + t.Fatal(err) + } + row := conn.QueryRowContext(ctx, testutil.SelectFooFromBar, ExecOptions{DirectExecuteQuery: true}) + var c int64 + // If we don't call row.Scan(..), then the underlying Rows object is not closed. That again means that the + // connection cannot be released. + if err := row.Scan(&c); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "commit"); err != nil { + t.Fatal(err) + } + + requests := server.TestSpanner.DrainRequestsFromServer() + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil { + t.Fatal("missing begin transaction on ExecuteSqlRequest") + } + // TODO: Enable once transaction_read_only is picked up by the driver. + //readOnly := request.GetTransaction().GetBegin().GetReadOnly() + //if readOnly == nil { + // t.Fatal("missing readOnly on ExecuteSqlRequest") + //} +} + +func TestBeginTransactionDeferrable(t *testing.T) { + t.Parallel() + + // BEGIN TRANSACTION [NOT] DEFERRABLE is only supported for PostgreSQL-dialect databases. + db, _, teardown := setupTestDBConnectionWithParamsAndDialect(t, "", databasepb.DatabaseDialect_POSTGRESQL) + defer teardown() + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer silentClose(conn) + + if _, err := conn.ExecContext(ctx, "begin transaction deferrable"); err != nil { + t.Fatal(err) + } + row := conn.QueryRowContext(ctx, testutil.SelectFooFromBar, ExecOptions{DirectExecuteQuery: true}) + var c int64 + if err := row.Scan(&c); err != nil { + t.Fatal(err) + } + + // transaction_deferrable is a no-op on Spanner, but the SQL statement is supported for + // PostgreSQL-dialect databases for compatibility reasons. + row = conn.QueryRowContext(ctx, "show transaction_deferrable") + var deferrable bool + if err := row.Scan(&deferrable); err != nil { + t.Fatal(err) + } + if _, err := conn.ExecContext(ctx, "commit"); err != nil { + t.Fatal(err) + } + + if g, w := deferrable, true; g != w { + t.Fatalf("deferrable mismatch\n Got: %v\nWant: %v", g, w) + } +}