Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,25 @@ func (c *conn) inReadWriteTransaction() bool {
return false
}

func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) {
if !c.inTransaction() {
return nil, status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
}
// TODO: Pass in context to the tx.Commit() function.
if err := c.tx.Commit(); err != nil {
return nil, err
}
return c.CommitResponse()
}

func (c *conn) rollback(ctx context.Context) error {
if !c.inTransaction() {
return status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
}
// TODO: Pass in context to the tx.Rollback() function.
return c.tx.Rollback()
}

func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions)
}
Expand Down
75 changes: 75 additions & 0 deletions conn_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/googleapis/go-sql-spanner/connectionstate"
"github.com/googleapis/go-sql-spanner/testutil"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/emptypb"
Expand Down Expand Up @@ -108,6 +109,80 @@ func TestExplicitBeginTx(t *testing.T) {
}
}

func TestExecuteBegin(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()
ctx := context.Background()

for _, end := range []string{"rollback", "commit"} {
c, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
if _, err := c.ExecContext(ctx, "begin transaction"); err != nil {
t.Fatal(err)
}
if _, err := c.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil {
t.Fatal(err)
}
if _, err := c.ExecContext(ctx, end); err != nil {
t.Fatal(err)
}
if err := c.Close(); err != nil {
t.Fatal(err)
}

requests := drainRequestsFromServer(server.TestSpanner)
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
if g, w := len(beginRequests), 0; g != w {
t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w)
}
executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
if g, w := len(executeRequests), 1; g != w {
t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w)
}
request := executeRequests[0].(*spannerpb.ExecuteSqlRequest)
if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil {
t.Fatal("missing begin transaction on ExecuteSqlRequest")
}
commitRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{}))
rollbackRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{}))
if end == "commit" {
if g, w := len(commitRequests), 1; g != w {
t.Fatalf("commit requests count mismatch\n Got: %v\nWant: %v", g, w)
}
} else if end == "rollback" {
if g, w := len(rollbackRequests), 1; g != w {
t.Fatalf("rollback requests count mismatch\n Got: %v\nWant: %v", g, w)
}
}
}
}

func TestEndTransactionWithoutBegin(t *testing.T) {
t.Parallel()

db, _, teardown := setupTestDBConnection(t)
defer teardown()
ctx := context.Background()

for _, end := range []string{"rollback", "commit"} {
c, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
_, err = c.ExecContext(ctx, end)
if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w {
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
}
if err := c.Close(); err != nil {
t.Fatal(err)
}
}
}

func TestBeginTxWithIsolationLevel(t *testing.T) {
t.Parallel()

Expand Down
13 changes: 10 additions & 3 deletions parser/simple_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,19 +282,26 @@ func (p *simpleParser) eatLiteral() (Literal, error) {
func (p *simpleParser) eatKeywords(keywords []string) bool {
startPos := p.pos
for _, keyword := range keywords {
if _, ok := p.eatKeyword(keyword); !ok {
if !p.eatKeyword(keyword) {
p.pos = startPos
return false
}
}
return true
}

// eatKeyword eats the given keyword at the current position of the parser if it exists.
// eatKeyword eats the given keyword at the current position of the parser if it exists
// and returns true if the keyword was found. Otherwise, it returns false.
func (p *simpleParser) eatKeyword(keyword string) bool {
_, ok := p.eatAndReturnKeyword(keyword)
return ok
}

// eatAndReturnKeyword eats the given keyword at the current position of the parser if it exists.
//
// Returns the actual keyword that was read and true if the keyword is found, and updates the position of the parser.
// Returns an empty string and false without updating the position of the parser if the keyword was not found.
func (p *simpleParser) eatKeyword(keyword string) (string, bool) {
func (p *simpleParser) eatAndReturnKeyword(keyword string) (string, bool) {
startPos := p.pos
found := p.readKeyword()
if !strings.EqualFold(found, keyword) {
Expand Down
34 changes: 26 additions & 8 deletions parser/statement_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,17 @@ var updateStatements = map[string]bool{"UPDATE": true}
var deleteStatements = map[string]bool{"DELETE": true}
var dmlStatements = union(insertStatements, union(updateStatements, deleteStatements))
var clientSideKeywords = map[string]bool{
"SHOW": true,
"SET": true,
"RESET": true,
"START": true,
"RUN": true,
"ABORT": true,
"CREATE": true, // CREATE DATABASE is handled as a client-side statement
"DROP": true, // DROP DATABASE is handled as a client-side statement
"SHOW": true,
"SET": true,
"RESET": true,
"START": true,
"RUN": true,
"ABORT": true,
"BEGIN": true,
"COMMIT": true,
"ROLLBACK": true,
"CREATE": true, // CREATE DATABASE is handled as a client-side statement
"DROP": true, // DROP DATABASE is handled as a client-side statement
}
var createStatements = map[string]bool{"CREATE": true}
var dropStatements = map[string]bool{"DROP": true}
Expand All @@ -52,6 +55,9 @@ var resetStatements = map[string]bool{"RESET": true}
var startStatements = map[string]bool{"START": true}
var runStatements = map[string]bool{"RUN": true}
var abortStatements = map[string]bool{"ABORT": true}
var beginStatements = map[string]bool{"BEGIN": true}
var commitStatements = map[string]bool{"COMMIT": true}
var rollbackStatements = map[string]bool{"ROLLBACK": true}

func union(m1 map[string]bool, m2 map[string]bool) map[string]bool {
res := make(map[string]bool, len(m1)+len(m2))
Expand Down Expand Up @@ -660,6 +666,18 @@ func isAbortStatementKeyword(keyword string) bool {
return isStatementKeyword(keyword, abortStatements)
}

func isBeginStatementKeyword(keyword string) bool {
return isStatementKeyword(keyword, beginStatements)
}

func isCommitStatementKeyword(keyword string) bool {
return isStatementKeyword(keyword, commitStatements)
}

func isRollbackStatementKeyword(keyword string) bool {
return isStatementKeyword(keyword, rollbackStatements)
}

func isStatementKeyword(keyword string, keywords map[string]bool) bool {
_, ok := keywords[keyword]
return ok
Expand Down
2 changes: 1 addition & 1 deletion parser/statement_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2467,7 +2467,7 @@ func TestEatKeyword(t *testing.T) {
for _, test := range tests {
sp := &simpleParser{sql: []byte(test.input), statementParser: parser}
startPos := sp.pos
keyword, ok := sp.eatKeyword(test.keyword)
keyword, ok := sp.eatAndReturnKeyword(test.keyword)
if g, w := ok, test.wantOk; g != w {
t.Errorf("found mismatch\n Got: %v\nWant: %v", g, w)
}
Expand Down
Loading
Loading