Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions sql/rowexec/ddl_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,8 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl
return false, err
}

rowIter := sql.NewTableRowIter(ctx, rwt, partitions)
var rowIter sql.RowIter = sql.NewTableRowIter(ctx, rwt, partitions)
rowIter = withSafepointPeriodicallyIter(rowIter)
for {
r, err := rowIter.Next(ctx)
if err == io.EOF {
Expand Down Expand Up @@ -1117,7 +1118,8 @@ func (c *createPkIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable) e
return err
}

rowIter := sql.NewTableRowIter(ctx, rwt, partitions)
var rowIter sql.RowIter = sql.NewTableRowIter(ctx, rwt, partitions)
rowIter = withSafepointPeriodicallyIter(rowIter)

for {
r, err := rowIter.Next(ctx)
Expand Down Expand Up @@ -1221,7 +1223,8 @@ func (d *dropPkIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable) err
return err
}

rowIter := sql.NewTableRowIter(ctx, rwt, partitions)
var rowIter sql.RowIter = sql.NewTableRowIter(ctx, rwt, partitions)
rowIter = withSafepointPeriodicallyIter(rowIter)

for {
r, err := rowIter.Next(ctx)
Expand Down Expand Up @@ -1329,6 +1332,7 @@ func (i *addColumnIter) UpdateRowsWithDefaults(ctx *sql.Context, table sql.Table
if err != nil {
return err
}
tableIter = withSafepointPeriodicallyIter(tableIter)

schema := updatable.Schema()
idx := -1
Expand Down Expand Up @@ -1430,7 +1434,8 @@ func (i *addColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable)
return false, err
}

rowIter := sql.NewTableRowIter(ctx, rwt, partitions)
var rowIter sql.RowIter = sql.NewTableRowIter(ctx, rwt, partitions)
rowIter = withSafepointPeriodicallyIter(rowIter)

var val uint64
var autoTbl sql.AutoIncrementTable
Expand Down Expand Up @@ -1740,7 +1745,8 @@ func (i *dropColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable)
return false, err
}

rowIter := sql.NewTableRowIter(ctx, rwt, partitions)
var rowIter sql.RowIter = sql.NewTableRowIter(ctx, rwt, partitions)
rowIter = withSafepointPeriodicallyIter(rowIter)

for {
r, err := rowIter.Next(ctx)
Expand Down Expand Up @@ -2252,7 +2258,8 @@ func buildIndex(ctx *sql.Context, n *plan.AlterIndex, ibt sql.IndexBuildingTable
return err
}

rowIter := sql.NewTableRowIter(ctx, ibt, partitions)
var rowIter sql.RowIter = sql.NewTableRowIter(ctx, ibt, partitions)
rowIter = withSafepointPeriodicallyIter(rowIter)

// Our table scan needs to include projections for virtual columns if there are any
isVirtual := ibt.Schema().HasVirtualColumns()
Expand Down Expand Up @@ -2339,7 +2346,8 @@ func rewriteTableForIndexCreate(ctx *sql.Context, n *plan.AlterIndex, table sql.
return err
}

rowIter := sql.NewTableRowIter(ctx, rwt, partitions)
var rowIter sql.RowIter = sql.NewTableRowIter(ctx, rwt, partitions)
rowIter = withSafepointPeriodicallyIter(rowIter)

isVirtual := table.Schema().HasVirtualColumns()
var projections []sql.Expression
Expand Down
48 changes: 44 additions & 4 deletions sql/rowexec/dml_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ func (i *triggerBlockIter) Next(ctx *sql.Context) (sql.Row, error) {
if err != nil {
return nil, err
}
subIter = withSafepointPeriodicallyIter(subIter)

for {
newRow, err := subIter.Next(ctx)
Expand All @@ -143,6 +144,7 @@ func (i *triggerBlockIter) Next(ctx *sql.Context) (sql.Row, error) {
}
}
}
sql.SessionCommandSafepoint(ctx.Session)

return row, nil
}
Expand Down Expand Up @@ -264,6 +266,7 @@ func (t *triggerIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) {
if err != nil {
return nil, err
}
logicIter = withSafepointPeriodicallyIter(logicIter)

defer func() {
err := logicIter.Close(t.ctx)
Expand Down Expand Up @@ -613,15 +616,15 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc
switch innerIter := i.InnerIter().(type) {
case *insertIter:
if len(innerIter.returnExprs) > 0 {
return innerIter, innerIter.returnSchema
return withSafepointPeriodicallyIter(innerIter), innerIter.returnSchema
}
case *updateIter:
if len(innerIter.returnExprs) > 0 {
return innerIter, innerIter.returnSchema
return withSafepointPeriodicallyIter(innerIter), innerIter.returnSchema
}
case *deleteIter:
if len(innerIter.returnExprs) > 0 {
return innerIter, innerIter.returnSchema
return withSafepointPeriodicallyIter(innerIter), innerIter.returnSchema
}
}

Expand All @@ -631,6 +634,43 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc
}
}

func withSafepointPeriodicallyIter(child sql.RowIter) *safepointPeriodicallyIter {
return &safepointPeriodicallyIter{child: child}
}

// A wrapper iterator which calls sql.SessionCommandSafepoint on the
// ctx.Session periodically while returning rows through calls to
// |Next|.
//
// Should be used to wrap any iterators which are involved in
// long-running write operations and which are exhausted or iterated
// by other iterators in the iterator tree, such as accumulatorIter.
//
// This iterator makes the assumption that a safepoint, from the
// Engine's perspective, can be established at any moment we are
// within a Next() call. This is generally true given the Engine's
// lack of concurrency on a given Session, but if something like
// Exchange node came back, this would not necessarily be true.
type safepointPeriodicallyIter struct {
child sql.RowIter
n int
}

const safepointEveryNRows = 1024

func (i *safepointPeriodicallyIter) Next(ctx *sql.Context) (r sql.Row, err error) {
i.n++
if i.n >= safepointEveryNRows {
i.n = 0
sql.SessionCommandSafepoint(ctx.Session)
}
return i.child.Next(ctx)
}

func (i *safepointPeriodicallyIter) Close(ctx *sql.Context) error {
return i.child.Close(ctx)
}

// defaultAccumulatorIter returns the default accumulator iter for a DML node
func defaultAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Schema) {
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0
Expand All @@ -639,7 +679,7 @@ func defaultAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sq
return iter, nil
}
return &accumulatorIter{
iter: iter,
iter: withSafepointPeriodicallyIter(iter),
updateRowHandler: rowHandler,
}, types.OkResultSchema
}
Expand Down
1 change: 1 addition & 0 deletions sql/rowexec/proc.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ func (b *BaseBuilder) buildLoop(ctx *sql.Context, n *plan.Loop, row sql.Row) (sq
return nil, err
}
}
loopBodyIter = withSafepointPeriodicallyIter(loopBodyIter)

includeResultSet := false

Expand Down
16 changes: 16 additions & 0 deletions sql/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,15 @@ type LifecycleAwareSession interface {
SessionEnd()
}

// An optional Lifecycle callback which a session can receive. This can be
// delivered periodically during a long running operation, between the
// CommandBegin and CommandEnd calls. Across the call to this method, the
// gms.Engine is not accessing the session or any of its state, such as
// table editors, database providers, etc.
type SafepointAwareSession interface {
CommandSafepoint()
}

type (
// TypedValue is a value along with its type.
TypedValue struct {
Expand Down Expand Up @@ -763,6 +772,13 @@ func SessionCommandEnd(s Session) {
}
}

// Helper function to call CommandSafepoint on a SafepointAwareSession, or do nothing.
func SessionCommandSafepoint(s Session) {
if cur, ok := s.(SafepointAwareSession); ok {
cur.CommandSafepoint()
}
}

// Helper function to call SessionEnd on a LifecycleAwareSession, or do nothing.
func SessionEnd(s Session) {
if cur, ok := s.(LifecycleAwareSession); ok {
Expand Down