diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index 67e8a95951..ecbc16be97 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -570,7 +570,8 @@ func (i *modifyColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTabl return false, err } - rowIter := sql.NewTableRowIter(ctx, rwt, partitions) + var rowIter sql.RowIter = sql.NewTableRowIter(ctx, rwt, partitions) + rowIter = withSafepointPeriodicallyIter(rowIter) for { r, err := rowIter.Next(ctx) if err == io.EOF { @@ -1117,7 +1118,8 @@ func (c *createPkIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable) e return err } - rowIter := sql.NewTableRowIter(ctx, rwt, partitions) + var rowIter sql.RowIter = sql.NewTableRowIter(ctx, rwt, partitions) + rowIter = withSafepointPeriodicallyIter(rowIter) for { r, err := rowIter.Next(ctx) @@ -1221,7 +1223,8 @@ func (d *dropPkIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable) err return err } - rowIter := sql.NewTableRowIter(ctx, rwt, partitions) + var rowIter sql.RowIter = sql.NewTableRowIter(ctx, rwt, partitions) + rowIter = withSafepointPeriodicallyIter(rowIter) for { r, err := rowIter.Next(ctx) @@ -1329,6 +1332,7 @@ func (i *addColumnIter) UpdateRowsWithDefaults(ctx *sql.Context, table sql.Table if err != nil { return err } + tableIter = withSafepointPeriodicallyIter(tableIter) schema := updatable.Schema() idx := -1 @@ -1430,7 +1434,8 @@ func (i *addColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable) return false, err } - rowIter := sql.NewTableRowIter(ctx, rwt, partitions) + var rowIter sql.RowIter = sql.NewTableRowIter(ctx, rwt, partitions) + rowIter = withSafepointPeriodicallyIter(rowIter) var val uint64 var autoTbl sql.AutoIncrementTable @@ -1740,7 +1745,8 @@ func (i *dropColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable) return false, err } - rowIter := sql.NewTableRowIter(ctx, rwt, partitions) + var rowIter sql.RowIter = sql.NewTableRowIter(ctx, rwt, partitions) + rowIter = withSafepointPeriodicallyIter(rowIter) for { r, err := rowIter.Next(ctx) @@ -2252,7 +2258,8 @@ func buildIndex(ctx *sql.Context, n *plan.AlterIndex, ibt sql.IndexBuildingTable return err } - rowIter := sql.NewTableRowIter(ctx, ibt, partitions) + var rowIter sql.RowIter = sql.NewTableRowIter(ctx, ibt, partitions) + rowIter = withSafepointPeriodicallyIter(rowIter) // Our table scan needs to include projections for virtual columns if there are any isVirtual := ibt.Schema().HasVirtualColumns() @@ -2339,7 +2346,8 @@ func rewriteTableForIndexCreate(ctx *sql.Context, n *plan.AlterIndex, table sql. return err } - rowIter := sql.NewTableRowIter(ctx, rwt, partitions) + var rowIter sql.RowIter = sql.NewTableRowIter(ctx, rwt, partitions) + rowIter = withSafepointPeriodicallyIter(rowIter) isVirtual := table.Schema().HasVirtualColumns() var projections []sql.Expression diff --git a/sql/rowexec/dml_iters.go b/sql/rowexec/dml_iters.go index 0682976e4e..8b40e0f335 100644 --- a/sql/rowexec/dml_iters.go +++ b/sql/rowexec/dml_iters.go @@ -121,6 +121,7 @@ func (i *triggerBlockIter) Next(ctx *sql.Context) (sql.Row, error) { if err != nil { return nil, err } + subIter = withSafepointPeriodicallyIter(subIter) for { newRow, err := subIter.Next(ctx) @@ -143,6 +144,7 @@ func (i *triggerBlockIter) Next(ctx *sql.Context) (sql.Row, error) { } } } + sql.SessionCommandSafepoint(ctx.Session) return row, nil } @@ -264,6 +266,7 @@ func (t *triggerIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) { if err != nil { return nil, err } + logicIter = withSafepointPeriodicallyIter(logicIter) defer func() { err := logicIter.Close(t.ctx) @@ -613,15 +616,15 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc switch innerIter := i.InnerIter().(type) { case *insertIter: if len(innerIter.returnExprs) > 0 { - return innerIter, innerIter.returnSchema + return withSafepointPeriodicallyIter(innerIter), innerIter.returnSchema } case *updateIter: if len(innerIter.returnExprs) > 0 { - return innerIter, innerIter.returnSchema + return withSafepointPeriodicallyIter(innerIter), innerIter.returnSchema } case *deleteIter: if len(innerIter.returnExprs) > 0 { - return innerIter, innerIter.returnSchema + return withSafepointPeriodicallyIter(innerIter), innerIter.returnSchema } } @@ -631,6 +634,43 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc } } +func withSafepointPeriodicallyIter(child sql.RowIter) *safepointPeriodicallyIter { + return &safepointPeriodicallyIter{child: child} +} + +// A wrapper iterator which calls sql.SessionCommandSafepoint on the +// ctx.Session periodically while returning rows through calls to +// |Next|. +// +// Should be used to wrap any iterators which are involved in +// long-running write operations and which are exhausted or iterated +// by other iterators in the iterator tree, such as accumulatorIter. +// +// This iterator makes the assumption that a safepoint, from the +// Engine's perspective, can be established at any moment we are +// within a Next() call. This is generally true given the Engine's +// lack of concurrency on a given Session, but if something like +// Exchange node came back, this would not necessarily be true. +type safepointPeriodicallyIter struct { + child sql.RowIter + n int +} + +const safepointEveryNRows = 1024 + +func (i *safepointPeriodicallyIter) Next(ctx *sql.Context) (r sql.Row, err error) { + i.n++ + if i.n >= safepointEveryNRows { + i.n = 0 + sql.SessionCommandSafepoint(ctx.Session) + } + return i.child.Next(ctx) +} + +func (i *safepointPeriodicallyIter) Close(ctx *sql.Context) error { + return i.child.Close(ctx) +} + // defaultAccumulatorIter returns the default accumulator iter for a DML node func defaultAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Schema) { clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0 @@ -639,7 +679,7 @@ func defaultAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sq return iter, nil } return &accumulatorIter{ - iter: iter, + iter: withSafepointPeriodicallyIter(iter), updateRowHandler: rowHandler, }, types.OkResultSchema } diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 35af6c65e8..877ba953f9 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -340,6 +340,7 @@ func (b *BaseBuilder) buildLoop(ctx *sql.Context, n *plan.Loop, row sql.Row) (sq return nil, err } } + loopBodyIter = withSafepointPeriodicallyIter(loopBodyIter) includeResultSet := false diff --git a/sql/session.go b/sql/session.go index de6ee992c8..ad655f2cc3 100644 --- a/sql/session.go +++ b/sql/session.go @@ -229,6 +229,15 @@ type LifecycleAwareSession interface { SessionEnd() } +// An optional Lifecycle callback which a session can receive. This can be +// delivered periodically during a long running operation, between the +// CommandBegin and CommandEnd calls. Across the call to this method, the +// gms.Engine is not accessing the session or any of its state, such as +// table editors, database providers, etc. +type SafepointAwareSession interface { + CommandSafepoint() +} + type ( // TypedValue is a value along with its type. TypedValue struct { @@ -763,6 +772,13 @@ func SessionCommandEnd(s Session) { } } +// Helper function to call CommandSafepoint on a SafepointAwareSession, or do nothing. +func SessionCommandSafepoint(s Session) { + if cur, ok := s.(SafepointAwareSession); ok { + cur.CommandSafepoint() + } +} + // Helper function to call SessionEnd on a LifecycleAwareSession, or do nothing. func SessionEnd(s Session) { if cur, ok := s.(LifecycleAwareSession); ok {