Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rollback and savepoint support #415

Merged
merged 5 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ func (h *Handler) doQuery(

tdb, ok := database.(sql.TransactionDatabase)
if ok {
tx, err := tdb.BeginTransaction(ctx)
tx, err := tdb.StartTransaction(ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -443,7 +443,7 @@ rowLoop:
if commitTransaction {
// TODO: unify this logic with Commit node
logrus.Tracef("committing transaction %s", tx)
if err := ctx.Session.CommitTransaction(ctx, getTransactionDbName(ctx)); err != nil {
if err := ctx.Session.CommitTransaction(ctx, getTransactionDbName(ctx), tx); err != nil {
return err
}
// Clearing out the current transaction will tell us to start a new one the next time this session queries
Expand Down
4 changes: 2 additions & 2 deletions sql/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,8 +499,8 @@ type Transaction interface {
type TransactionDatabase interface {
Database

// BeginTransaction starts a new transaction and returns it
BeginTransaction(ctx *Context) (Transaction, error)
// StartTransaction starts a new transaction and returns it
StartTransaction(ctx *Context) (Transaction, error)

// CommitTransaction commits the transaction given
CommitTransaction(ctx *Context, tx Transaction) error
Expand Down
4 changes: 4 additions & 0 deletions sql/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ var (

// ErrInvalidArgument is returned when an argument to a function is invalid.
ErrInvalidArgument = errors.NewKind("Incorrect arguments to %s")

// ErrSavepointDoesNotExist is returned when a RELEASE SAVEPOINT or ROLLBACK TO SAVEPOINT statement references a
// non-existent savepoint identifier
ErrSavepointDoesNotExist = errors.NewKind("SAVEPOINT %s does not exist")
)

func CastSQLError(err error) (*mysql.SQLError, bool) {
Expand Down
39 changes: 21 additions & 18 deletions sql/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,6 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node
switch n := stmt.(type) {
default:
return nil, ErrUnsupportedSyntax.New(sqlparser.String(n))
case *sqlparser.BeginEndBlock:
return convertBeginEndBlock(ctx, n, query)
case *sqlparser.IfStatement:
return convertIfBlock(ctx, n)
case *sqlparser.Show:
// When a query is empty it means it comes from a subquery, as we don't
// have the query itself in a subquery. Hence, a SHOW could not be
Expand All @@ -166,12 +162,9 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node
return nil, ErrUnsupportedFeature.New("SHOW in subquery")
}
return convertShow(ctx, n, query)
case *sqlparser.Explain:
return convertExplain(ctx, n)
case *sqlparser.Insert:
return convertInsert(ctx, n)
case *sqlparser.DDL:
// unlike other statements, DDL statements have loose parsing by default
// TODO: fix this
ddl, err := sqlparser.ParseStrictDDL(query)
if err != nil {
return nil, err
Expand All @@ -185,6 +178,16 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node
return convertMultiAlterDDL(ctx, query, multiAlterDdl.(*sqlparser.MultiAlterDDL))
case *sqlparser.DBDDL:
return convertDBDDL(n)
case *sqlparser.Explain:
return convertExplain(ctx, n)
case *sqlparser.Insert:
return convertInsert(ctx, n)
case *sqlparser.Delete:
return convertDelete(ctx, n)
case *sqlparser.Update:
return convertUpdate(ctx, n)
case *sqlparser.Load:
return convertLoad(ctx, n)
case *sqlparser.Set:
return convertSet(ctx, n)
case *sqlparser.Use:
Expand All @@ -195,22 +198,22 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node
return plan.NewCommit(""), nil
case *sqlparser.Rollback:
return plan.NewRollback(""), nil
case *sqlparser.Delete:
return convertDelete(ctx, n)
case *sqlparser.Update:
return convertUpdate(ctx, n)
case *sqlparser.Load:
return convertLoad(ctx, n)
case *sqlparser.Savepoint:
return plan.NewCreateSavepoint("", n.Identifier), nil
case *sqlparser.RollbackSavepoint:
return plan.NewRollbackSavepoint("", n.Identifier), nil
case *sqlparser.ReleaseSavepoint:
return plan.NewReleaseSavepoint("", n.Identifier), nil
case *sqlparser.BeginEndBlock:
return convertBeginEndBlock(ctx, n, query)
case *sqlparser.IfStatement:
return convertIfBlock(ctx, n)
case *sqlparser.Call:
return convertCall(ctx, n)
case *sqlparser.Declare:
return convertDeclare(ctx, n)
case *sqlparser.Signal:
return convertSignal(ctx, n)
//TODO: implement the SAVEPOINT statements used in transactions, currently here for integration compatibility
case *sqlparser.Savepoint, *sqlparser.RollbackSavepoint, *sqlparser.ReleaseSavepoint:
ctx.Warn(1642, "SAVEPOINT functionality is currently a no-op")
return plan.NewBlock(nil), nil // An empty block is essentially a no-op
}
}

Expand Down
6 changes: 6 additions & 0 deletions sql/parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2366,7 +2366,13 @@ CREATE TABLE t2
),
showCollationProjection,
),
"BEGIN": plan.NewStartTransaction(""),
"START TRANSACTION": plan.NewStartTransaction(""),
"COMMIT": plan.NewCommit(""),
`ROLLBACK`: plan.NewRollback(""),
"SAVEPOINT abc": plan.NewCreateSavepoint("", "abc"),
"ROLLBACK TO SAVEPOINT abc": plan.NewRollbackSavepoint("", "abc"),
"RELEASE SAVEPOINT abc": plan.NewReleaseSavepoint("", "abc"),
"SHOW CREATE TABLE `mytable`": plan.NewShowCreateTable(plan.NewUnresolvedTable("mytable", ""), false),
"SHOW CREATE TABLE mytable": plan.NewShowCreateTable(plan.NewUnresolvedTable("mytable", ""), false),
"SHOW CREATE TABLE mydb.`mytable`": plan.NewShowCreateTable(plan.NewUnresolvedTable("mytable", "mydb"), false),
Expand Down
219 changes: 217 additions & 2 deletions sql/plan/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

package plan

import "github.com/dolthub/go-mysql-server/sql"
import (
"fmt"

"github.com/dolthub/go-mysql-server/sql"
)

// StartTransaction explicitly starts a transaction. Transactions also start before any statement execution that doesn't have a
// transaction.
Expand Down Expand Up @@ -63,7 +67,7 @@ func (s *StartTransaction) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter,
}
}

transaction, err := tdb.BeginTransaction(ctx)
transaction, err := tdb.StartTransaction(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -246,6 +250,10 @@ func (r *Rollback) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) {
return nil, err
}

// Like Commit, Rollback ends the current transaction and a new one begins with the next statement
ctx.SetIgnoreAutoCommit(false)
ctx.SetTransaction(nil)

return sql.RowsToRowIter(), nil
}

Expand Down Expand Up @@ -280,3 +288,210 @@ func (*Rollback) Children() []sql.Node { return nil }

// Schema implements the sql.Node interface.
func (*Rollback) Schema() sql.Schema { return nil }

type CreateSavepoint struct {
name string
db sql.Database
}

var _ sql.Databaser = (*CreateSavepoint)(nil)
var _ sql.Node = (*CreateSavepoint)(nil)

// NewCreateSavepoint creates a new CreateSavepoint node.
func NewCreateSavepoint(db sql.UnresolvedDatabase, name string) *CreateSavepoint {
return &CreateSavepoint{
db: db,
name: name,
}
}

// RowIter implements the sql.Node interface.
func (c *CreateSavepoint) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) {
tdb, ok := c.db.(sql.TransactionDatabase)
if !ok {
return sql.RowsToRowIter(), nil
}

transaction := ctx.GetTransaction()

if transaction == nil {
return sql.RowsToRowIter(), nil
}

err := tdb.CreateSavepoint(ctx, transaction, c.name)
if err != nil {
return nil, err
}

return sql.RowsToRowIter(), nil
}

func (c *CreateSavepoint) Database() sql.Database {
return c.db
}

func (c CreateSavepoint) WithDatabase(database sql.Database) (sql.Node, error) {
c.db = database
return &c, nil
}

func (c *CreateSavepoint) String() string { return fmt.Sprintf("SAVEPOINT %s", c.name) }

// WithChildren implements the Node interface.
func (c *CreateSavepoint) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 0 {
return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0)
}

return c, nil
}

// Resolved implements the sql.Node interface.
func (c *CreateSavepoint) Resolved() bool {
_, ok := c.db.(sql.UnresolvedDatabase)
return !ok
}

// Children implements the sql.Node interface.
func (*CreateSavepoint) Children() []sql.Node { return nil }

// Schema implements the sql.Node interface.
func (*CreateSavepoint) Schema() sql.Schema { return nil }

type RollbackSavepoint struct {
name string
db sql.Database
}

var _ sql.Databaser = (*RollbackSavepoint)(nil)
var _ sql.Node = (*RollbackSavepoint)(nil)

// NewRollbackSavepoint creates a new RollbackSavepoint node.
func NewRollbackSavepoint(db sql.UnresolvedDatabase, name string) *RollbackSavepoint {
return &RollbackSavepoint{
db: db,
name: name,
}
}

// RowIter implements the sql.Node interface.
func (r *RollbackSavepoint) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) {
tdb, ok := r.db.(sql.TransactionDatabase)
if !ok {
return sql.RowsToRowIter(), nil
}

transaction := ctx.GetTransaction()

if transaction == nil {
return sql.RowsToRowIter(), nil
}

err := tdb.RollbackToSavepoint(ctx, transaction, r.name)
if err != nil {
return nil, err
}

return sql.RowsToRowIter(), nil
}

func (r *RollbackSavepoint) Database() sql.Database {
return r.db
}

func (r RollbackSavepoint) WithDatabase(database sql.Database) (sql.Node, error) {
r.db = database
return &r, nil
}

func (r *RollbackSavepoint) String() string { return fmt.Sprintf("ROLLBACK TO SAVEPOINT %s", r.name) }

// WithChildren implements the Node interface.
func (r *RollbackSavepoint) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 0 {
return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0)
}

return r, nil
}

// Resolved implements the sql.Node interface.
func (r *RollbackSavepoint) Resolved() bool {
_, ok := r.db.(sql.UnresolvedDatabase)
return !ok
}

// Children implements the sql.Node interface.
func (*RollbackSavepoint) Children() []sql.Node { return nil }

// Schema implements the sql.Node interface.
func (*RollbackSavepoint) Schema() sql.Schema { return nil }

type ReleaseSavepoint struct {
name string
db sql.Database
}

var _ sql.Databaser = (*ReleaseSavepoint)(nil)
var _ sql.Node = (*ReleaseSavepoint)(nil)

// NewReleaseSavepoint creates a new ReleaseSavepoint node.
func NewReleaseSavepoint(db sql.UnresolvedDatabase, name string) *ReleaseSavepoint {
return &ReleaseSavepoint{
db: db,
name: name,
}
}

// RowIter implements the sql.Node interface.
func (r *ReleaseSavepoint) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) {
tdb, ok := r.db.(sql.TransactionDatabase)
if !ok {
return sql.RowsToRowIter(), nil
}

transaction := ctx.GetTransaction()

if transaction == nil {
return sql.RowsToRowIter(), nil
}

err := tdb.RollbackToSavepoint(ctx, transaction, r.name)
if err != nil {
return nil, err
}

return sql.RowsToRowIter(), nil
}

func (r *ReleaseSavepoint) Database() sql.Database {
return r.db
}

func (r ReleaseSavepoint) WithDatabase(database sql.Database) (sql.Node, error) {
r.db = database
return &r, nil
}

func (r *ReleaseSavepoint) String() string { return fmt.Sprintf("RELEASE SAVEPOINT %s", r.name) }

// WithChildren implements the Node interface.
func (r *ReleaseSavepoint) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 0 {
return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0)
}

return r, nil
}

// Resolved implements the sql.Node interface.
func (r *ReleaseSavepoint) Resolved() bool {
_, ok := r.db.(sql.UnresolvedDatabase)
return !ok
}

// Children implements the sql.Node interface.
func (*ReleaseSavepoint) Children() []sql.Node { return nil }

// Schema implements the sql.Node interface.
func (*ReleaseSavepoint) Schema() sql.Schema { return nil }
4 changes: 2 additions & 2 deletions sql/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ type Session interface {
// SetCurrentDatabase sets the current database for this session
SetCurrentDatabase(dbName string)
// CommitTransaction commits the current transaction for this session for the current database
CommitTransaction(ctx *Context, dbName string) error
CommitTransaction(ctx *Context, dbName string, transaction Transaction) error
// ID returns the unique ID of the connection.
ID() uint32
// Warn stores the warning in the session.
Expand Down Expand Up @@ -135,7 +135,7 @@ func (s *BaseSession) GetIgnoreAutoCommit() bool {
var _ Session = (*BaseSession)(nil)

// CommitTransaction commits the current transaction for the current database.
func (s *BaseSession) CommitTransaction(*Context, string) error {
func (s *BaseSession) CommitTransaction(*Context, string, Transaction) error {
// no-op on BaseSession
return nil
}
Expand Down