diff --git a/.golangci.yml b/.golangci.yml index 5dae41597902c..5858c0e6a5453 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -41,6 +41,8 @@ linters-settings: forbid: - '^fmt\.Errorf(# use errors\.Errorf instead)?$' - '^logrus\.(Trace|Debug|Info|Warn|Warning|Error|Fatal)(f|ln)?(# use bklog\.G or bklog\.L instead of logrus directly)?$' + - '^context\.WithCancel(# use context\.WithCancelCause instead)?$' + - '^ctx\.Err(# use context\.Cause instead)?$' importas: alias: - pkg: "github.com/opencontainers/image-spec/specs-go/v1" diff --git a/cache/manager.go b/cache/manager.go index c22dd3d16dcc6..c09ada3b99195 100644 --- a/cache/manager.go +++ b/cache/manager.go @@ -1258,7 +1258,7 @@ func (cm *cacheManager) prune(ctx context.Context, ch chan client.UsageInfo, opt select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) default: return cm.prune(ctx, ch, opt) } diff --git a/client/build_test.go b/client/build_test.go index 7611f140fecc2..7b325cb05dd65 100644 --- a/client/build_test.go +++ b/client/build_test.go @@ -1266,8 +1266,8 @@ func testClientGatewayContainerCancelExecTty(t *testing.T, sb integration.Sandbo defer pid1.Wait() defer ctr.Release(ctx) - execCtx, cancel := context.WithCancel(ctx) - defer cancel() + execCtx, cancel := context.WithCancelCause(ctx) + defer cancel(errors.WithStack(context.Canceled)) prompt := newTestPrompt(execCtx, t, inputW, output) pid2, err := ctr.Start(execCtx, client.StartRequest{ @@ -1281,7 +1281,7 @@ func testClientGatewayContainerCancelExecTty(t *testing.T, sb integration.Sandbo require.NoError(t, err) prompt.SendExpect("echo hi", "hi") - cancel() + cancel(errors.WithStack(context.Canceled)) err = pid2.Wait() require.ErrorIs(t, err, context.Canceled) diff --git a/client/client.go b/client/client.go index 71a72bf9f6a53..ea8b0e0e0b653 100644 --- a/client/client.go +++ b/client/client.go @@ -205,7 +205,7 @@ func (c *Client) Wait(ctx context.Context) error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-time.After(time.Second): } c.conn.ResetConnectBackoff() diff --git a/client/client_test.go b/client/client_test.go index 49b2fe7b6752c..dafe667188880 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -7407,8 +7407,8 @@ func testInvalidExporter(t *testing.T, sb integration.Sandbox) { // moby/buildkit#492 func testParallelLocalBuilds(t *testing.T, sb integration.Sandbox) { - ctx, cancel := context.WithCancel(sb.Context()) - defer cancel() + ctx, cancel := context.WithCancelCause(sb.Context()) + defer cancel(errors.WithStack(context.Canceled)) c, err := New(ctx, sb.Address()) require.NoError(t, err) diff --git a/client/llb/async.go b/client/llb/async.go index 8771c71978f89..cadbb5ef363ea 100644 --- a/client/llb/async.go +++ b/client/llb/async.go @@ -61,7 +61,7 @@ func (as *asyncState) Do(ctx context.Context, c *Constraints) error { if err != nil { select { case <-ctx.Done(): - if errors.Is(err, ctx.Err()) { + if errors.Is(err, context.Cause(ctx)) { return res, err } default: diff --git a/client/solve.go b/client/solve.go index 04090ad1400f2..1c6f1489ed9be 100644 --- a/client/solve.go +++ b/client/solve.go @@ -106,8 +106,8 @@ func (c *Client) solve(ctx context.Context, def *llb.Definition, runGateway runG } eg, ctx := errgroup.WithContext(ctx) - statusContext, cancelStatus := context.WithCancel(context.Background()) - defer cancelStatus() + statusContext, cancelStatus := context.WithCancelCause(context.Background()) + defer cancelStatus(errors.WithStack(context.Canceled)) if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() { statusContext = trace.ContextWithSpan(statusContext, span) @@ -230,16 +230,16 @@ func (c *Client) solve(ctx context.Context, def *llb.Definition, runGateway runG frontendAttrs[k] = v } - solveCtx, cancelSolve := context.WithCancel(ctx) + solveCtx, cancelSolve := context.WithCancelCause(ctx) var res *SolveResponse eg.Go(func() error { ctx := solveCtx - defer cancelSolve() + defer cancelSolve(errors.WithStack(context.Canceled)) defer func() { // make sure the Status ends cleanly on build errors go func() { <-time.After(3 * time.Second) - cancelStatus() + cancelStatus(errors.WithStack(context.Canceled)) }() if !opt.SessionPreInitialized { bklog.G(ctx).Debugf("stopping session") @@ -298,7 +298,7 @@ func (c *Client) solve(ctx context.Context, def *llb.Definition, runGateway runG select { case <-solveCtx.Done(): case <-time.After(5 * time.Second): - cancelSolve() + cancelSolve(errors.WithStack(context.Canceled)) } return err diff --git a/cmd/buildkitd/main.go b/cmd/buildkitd/main.go index f80c62d9a399a..24ffa2cce3421 100644 --- a/cmd/buildkitd/main.go +++ b/cmd/buildkitd/main.go @@ -223,8 +223,8 @@ func main() { if os.Geteuid() > 0 { return errors.New("rootless mode requires to be executed as the mapped root in a user namespace; you may use RootlessKit for setting up the namespace") } - ctx, cancel := context.WithCancel(appcontext.Context()) - defer cancel() + ctx, cancel := context.WithCancelCause(appcontext.Context()) + defer cancel(errors.WithStack(context.Canceled)) cfg, err := config.LoadFile(c.GlobalString("config")) if err != nil { @@ -344,9 +344,9 @@ func main() { select { case serverErr := <-errCh: err = serverErr - cancel() + cancel(err) case <-ctx.Done(): - err = ctx.Err() + err = context.Cause(ctx) } bklog.G(ctx).Infof("stopping server") @@ -634,14 +634,14 @@ func unaryInterceptor(globalCtx context.Context, tp trace.TracerProvider) grpc.U withTrace := otelgrpc.UnaryServerInterceptor(otelgrpc.WithTracerProvider(tp), otelgrpc.WithPropagators(propagators)) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(errors.WithStack(context.Canceled)) go func() { select { case <-ctx.Done(): case <-globalCtx.Done(): - cancel() + cancel(context.Cause(globalCtx)) } }() diff --git a/control/control.go b/control/control.go index 34c88bc1b4b89..3d6409c529b88 100644 --- a/control/control.go +++ b/control/control.go @@ -505,10 +505,10 @@ func (c *Controller) Session(stream controlapi.Control_SessionServer) error { conn, closeCh, opts := grpchijack.Hijack(stream) defer conn.Close() - ctx, cancel := context.WithCancel(stream.Context()) + ctx, cancel := context.WithCancelCause(stream.Context()) go func() { <-closeCh - cancel() + cancel(errors.WithStack(context.Canceled)) }() err := c.opt.SessionManager.HandleConn(ctx, conn, opts) diff --git a/executor/containerdexecutor/executor.go b/executor/containerdexecutor/executor.go index 8347c2001bb0f..5089bab0ead2c 100644 --- a/executor/containerdexecutor/executor.go +++ b/executor/containerdexecutor/executor.go @@ -243,7 +243,7 @@ func (w *containerdExecutor) Exec(ctx context.Context, id string, process execut } select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case err, ok := <-details.done: if !ok || err == nil { return errors.Errorf("container %s has stopped", id) @@ -336,8 +336,8 @@ func (w *containerdExecutor) runProcess(ctx context.Context, p containerd.Proces // handle signals (and resize) in separate go loop so it does not // potentially block the container cancel/exit status loop below. - eventCtx, eventCancel := context.WithCancel(ctx) - defer eventCancel() + eventCtx, eventCancel := context.WithCancelCause(ctx) + defer eventCancel(errors.WithStack(context.Canceled)) go func() { for { select { @@ -403,7 +403,7 @@ func (w *containerdExecutor) runProcess(ctx context.Context, p containerd.Proces } select { case <-ctx.Done(): - exitErr.Err = errors.Wrap(ctx.Err(), exitErr.Error()) + exitErr.Err = errors.Wrap(context.Cause(ctx), exitErr.Error()) default: } return exitErr diff --git a/executor/runcexecutor/executor.go b/executor/runcexecutor/executor.go index e804ee850b28f..702419a4db0ea 100644 --- a/executor/runcexecutor/executor.go +++ b/executor/runcexecutor/executor.go @@ -369,7 +369,7 @@ func exitError(ctx context.Context, err error) error { ) select { case <-ctx.Done(): - exitErr.Err = errors.Wrapf(ctx.Err(), exitErr.Error()) + exitErr.Err = errors.Wrapf(context.Cause(ctx), exitErr.Error()) return exitErr default: return stack.Enable(exitErr) @@ -402,7 +402,7 @@ func (w *runcExecutor) Exec(ctx context.Context, id string, process executor.Pro } select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case err, ok := <-done: if !ok || err == nil { return errors.Errorf("container %s has stopped", id) @@ -580,7 +580,7 @@ type procHandle struct { monitorProcess *os.Process ready chan struct{} ended chan struct{} - shutdown func() + shutdown func(error) // this this only used when the request context is canceled and we need // to kill the in-container process. killer procKiller @@ -594,7 +594,7 @@ type procHandle struct { // The goal is to allow for runc to gracefully shutdown when the request context // is cancelled. func runcProcessHandle(ctx context.Context, killer procKiller) (*procHandle, context.Context) { - runcCtx, cancel := context.WithCancel(context.Background()) + runcCtx, cancel := context.WithCancelCause(context.Background()) p := &procHandle{ ready: make(chan struct{}), ended: make(chan struct{}), @@ -620,7 +620,7 @@ func runcProcessHandle(ctx context.Context, killer procKiller) (*procHandle, con select { case <-killCtx.Done(): timeout() - cancel() + cancel(errors.WithStack(context.Cause(ctx))) return default: } @@ -653,7 +653,7 @@ func (p *procHandle) Release() { // goroutines. func (p *procHandle) Shutdown() { if p.shutdown != nil { - p.shutdown() + p.shutdown(errors.WithStack(context.Canceled)) } } @@ -663,7 +663,7 @@ func (p *procHandle) Shutdown() { func (p *procHandle) WaitForReady(ctx context.Context) error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-p.ready: return nil } diff --git a/frontend/gateway/container/container.go b/frontend/gateway/container/container.go index af6476e7fce2b..8b876890c2922 100644 --- a/frontend/gateway/container/container.go +++ b/frontend/gateway/container/container.go @@ -48,7 +48,7 @@ type Mount struct { } func NewContainer(ctx context.Context, w worker.Worker, sm *session.Manager, g session.Group, req NewContainerRequest) (client.Container, error) { - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) eg, ctx := errgroup.WithContext(ctx) platform := opspb.Platform{ OS: runtime.GOOS, @@ -300,7 +300,7 @@ type gatewayContainer struct { mu sync.Mutex cleanup []func() error ctx context.Context - cancel func() + cancel func(error) } func (gwCtr *gatewayContainer) Start(ctx context.Context, req client.StartRequest) (client.ContainerProcess, error) { @@ -408,7 +408,7 @@ func (gwCtr *gatewayContainer) loadSecretEnv(ctx context.Context, secretEnv []*p func (gwCtr *gatewayContainer) Release(ctx context.Context) error { gwCtr.mu.Lock() defer gwCtr.mu.Unlock() - gwCtr.cancel() + gwCtr.cancel(errors.WithStack(context.Canceled)) err1 := gwCtr.errGroup.Wait() var err2 error diff --git a/frontend/gateway/gateway.go b/frontend/gateway/gateway.go index 3b23386fc8e4d..24b645f13bc3e 100644 --- a/frontend/gateway/gateway.go +++ b/frontend/gateway/gateway.go @@ -456,7 +456,7 @@ func newBridgeForwarder(ctx context.Context, llbBridge frontend.FrontendLLBBridg } func serveLLBBridgeForwarder(ctx context.Context, llbBridge frontend.FrontendLLBBridge, workers worker.Infos, inputs map[string]*opspb.Definition, sid string, sm *session.Manager) (*llbBridgeForwarder, context.Context, error) { - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) lbf := newBridgeForwarder(ctx, llbBridge, workers, inputs, sid, sm) server := grpc.NewServer(grpc.UnaryInterceptor(grpcerrors.UnaryServerInterceptor), grpc.StreamInterceptor(grpcerrors.StreamServerInterceptor)) grpc_health_v1.RegisterHealthServer(server, health.NewServer()) @@ -469,7 +469,7 @@ func serveLLBBridgeForwarder(ctx context.Context, llbBridge frontend.FrontendLLB default: lbf.isErrServerClosed = true } - cancel() + cancel(errors.WithStack(context.Canceled)) }() return lbf, ctx, nil @@ -1322,8 +1322,8 @@ func (lbf *llbBridgeForwarder) ExecProcess(srv pb.LLBBridge_ExecProcessServer) e return stack.Enable(status.Errorf(codes.NotFound, "container %q previously released or not created", id)) } - initCtx, initCancel := context.WithCancel(context.Background()) - defer initCancel() + initCtx, initCancel := context.WithCancelCause(context.Background()) + defer initCancel(errors.WithStack(context.Canceled)) pio := newProcessIO(pid, init.Fds) pios[pid] = pio diff --git a/frontend/gateway/grpcclient/client.go b/frontend/gateway/grpcclient/client.go index 524b3ba2a966b..56322b2785e83 100644 --- a/frontend/gateway/grpcclient/client.go +++ b/frontend/gateway/grpcclient/client.go @@ -616,7 +616,7 @@ func (b *procMessageForwarder) Close() { type messageForwarder struct { client pb.LLBBridgeClient ctx context.Context - cancel func() + cancel func(error) eg *errgroup.Group mu sync.Mutex pids map[string]*procMessageForwarder @@ -630,7 +630,7 @@ type messageForwarder struct { } func newMessageForwarder(ctx context.Context, client pb.LLBBridgeClient) *messageForwarder { - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) eg, ctx := errgroup.WithContext(ctx) return &messageForwarder{ client: client, @@ -719,7 +719,7 @@ func (m *messageForwarder) Send(msg *pb.ExecMessage) error { } func (m *messageForwarder) Release() error { - m.cancel() + m.cancel(errors.WithStack(context.Canceled)) return m.eg.Wait() } @@ -949,7 +949,7 @@ func (ctr *container) Start(ctx context.Context, req client.StartRequest) (clien closeDoneOnce.Do(func() { close(done) }) - return ctx.Err() + return context.Cause(ctx) } if file := msg.GetFile(); file != nil { @@ -1145,7 +1145,7 @@ func grpcClientConn(ctx context.Context) (context.Context, *grpc.ClientConn, err return nil, nil, errors.Wrap(err, "failed to create grpc client") } - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) _ = cancel // go monitorHealth(ctx, cc, cancel) diff --git a/session/filesync/filesync.go b/session/filesync/filesync.go index 7254ddc08a39d..f05c475f6dcf3 100644 --- a/session/filesync/filesync.go +++ b/session/filesync/filesync.go @@ -195,8 +195,8 @@ func FSSync(ctx context.Context, c session.Caller, opt FSSendRequestOpt) error { opts[keyDirName] = []string{opt.Name} - ctx, cancel := context.WithCancel(ctx) - defer cancel() + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(errors.WithStack(context.Canceled)) client := NewFileSyncClient(c.Conn()) diff --git a/session/grpc.go b/session/grpc.go index 0e475199acaf9..b16e927aa856b 100644 --- a/session/grpc.go +++ b/session/grpc.go @@ -74,14 +74,14 @@ func grpcClientConn(ctx context.Context, conn net.Conn) (context.Context, *grpc. return nil, nil, errors.Wrap(err, "failed to create grpc client") } - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) go monitorHealth(ctx, cc, cancel) return ctx, cc, nil } -func monitorHealth(ctx context.Context, cc *grpc.ClientConn, cancelConn func()) { - defer cancelConn() +func monitorHealth(ctx context.Context, cc *grpc.ClientConn, cancelConn func(error)) { + defer cancelConn(errors.WithStack(context.Canceled)) defer cc.Close() ticker := time.NewTicker(5 * time.Second) diff --git a/session/manager.go b/session/manager.go index 2678e6738dab5..2eda89d2be075 100644 --- a/session/manager.go +++ b/session/manager.go @@ -99,8 +99,8 @@ func (sm *Manager) HandleConn(ctx context.Context, conn net.Conn, opts map[strin // caller needs to take lock, this function will release it func (sm *Manager) handleConn(ctx context.Context, conn net.Conn, opts map[string][]string) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(errors.WithStack(context.Canceled)) opts = canonicalHeaders(opts) @@ -156,8 +156,8 @@ func (sm *Manager) Get(ctx context.Context, id string, noWait bool) (Caller, err id = p[1] } - ctx, cancel := context.WithCancel(ctx) - defer cancel() + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(errors.WithStack(context.Canceled)) go func() { <-ctx.Done() @@ -173,7 +173,7 @@ func (sm *Manager) Get(ctx context.Context, id string, noWait bool) (Caller, err select { case <-ctx.Done(): sm.mu.Unlock() - return nil, errors.Wrapf(ctx.Err(), "no active session for %s", id) + return nil, errors.Wrapf(context.Cause(ctx), "no active session for %s", id) default: } var ok bool diff --git a/session/session.go b/session/session.go index f56a18730d22e..f9a56b88f9703 100644 --- a/session/session.go +++ b/session/session.go @@ -42,7 +42,7 @@ type Session struct { name string sharedKey string ctx context.Context - cancelCtx func() + cancelCtx func(error) done chan struct{} grpcServer *grpc.Server conn net.Conn @@ -107,11 +107,11 @@ func (s *Session) Run(ctx context.Context, dialer Dialer) error { s.mu.Unlock() return nil } - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) s.cancelCtx = cancel s.done = make(chan struct{}) - defer cancel() + defer cancel(errors.WithStack(context.Canceled)) defer close(s.done) meta := make(map[string][]string) diff --git a/session/sshforward/copy.go b/session/sshforward/copy.go index eac5f7614a7d3..804debd16df41 100644 --- a/session/sshforward/copy.go +++ b/session/sshforward/copy.go @@ -39,7 +39,7 @@ func Copy(ctx context.Context, conn io.ReadWriteCloser, stream Stream, closeStre select { case <-ctx.Done(): conn.Close() - return ctx.Err() + return context.Cause(ctx) default: } if _, err := conn.Write(p.Data); err != nil { @@ -65,7 +65,7 @@ func Copy(ctx context.Context, conn io.ReadWriteCloser, stream Stream, closeStre } select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) default: } p := &BytesMessage{Data: buf[:n]} diff --git a/session/sshforward/ssh.go b/session/sshforward/ssh.go index a808fcb1f077d..8a041b311f769 100644 --- a/session/sshforward/ssh.go +++ b/session/sshforward/ssh.go @@ -26,7 +26,7 @@ func (s *server) run(ctx context.Context, l net.Listener, id string) error { eg.Go(func() error { <-ctx.Done() - return ctx.Err() + return context.Cause(ctx) }) eg.Go(func() error { diff --git a/snapshot/diffapply_unix.go b/snapshot/diffapply_unix.go index 5aa73dd0a7282..72395d33a668a 100644 --- a/snapshot/diffapply_unix.go +++ b/snapshot/diffapply_unix.go @@ -600,8 +600,10 @@ func (d *differ) doubleWalkingChanges(ctx context.Context, handle func(context.C if prevErr != nil { return prevErr } - if ctx.Err() != nil { - return ctx.Err() + select { + case <-ctx.Done(): + return context.Cause(ctx) + default: } if kind == fs.ChangeKindUnmodified { @@ -689,8 +691,11 @@ func (d *differ) overlayChanges(ctx context.Context, handle func(context.Context if prevErr != nil { return prevErr } - if ctx.Err() != nil { - return ctx.Err() + + select { + case <-ctx.Done(): + return context.Cause(ctx) + default: } if kind == fs.ChangeKindUnmodified { diff --git a/solver/errdefs/context.go b/solver/errdefs/context.go index 9e0c5bb990c67..68779c444674f 100644 --- a/solver/errdefs/context.go +++ b/solver/errdefs/context.go @@ -14,7 +14,7 @@ func IsCanceled(ctx context.Context, err error) bool { return true } // grpc does not set cancel correctly when stream gets cancelled and then Recv is called - if err != nil && ctx.Err() == context.Canceled { + if err != nil && context.Cause(ctx) == context.Canceled { // when this error comes from containerd it is not typed at all, just concatenated string if strings.Contains(err.Error(), "EOF") { return true diff --git a/solver/internal/pipe/pipe.go b/solver/internal/pipe/pipe.go index a1a857f39827c..01cfa5a06b6dc 100644 --- a/solver/internal/pipe/pipe.go +++ b/solver/internal/pipe/pipe.go @@ -70,11 +70,11 @@ type Status struct { func NewWithFunction(f func(context.Context) (interface{}, error)) (*Pipe, func()) { p := New(Request{}) - ctx, cancel := context.WithCancel(context.TODO()) + ctx, cancel := context.WithCancelCause(context.TODO()) p.OnReceiveCompletion = func() { if req := p.Sender.Request(); req.Canceled { - cancel() + cancel(errors.WithStack(context.Canceled)) } } diff --git a/solver/internal/pipe/pipe_test.go b/solver/internal/pipe/pipe_test.go index 54b59af350661..1e233caecf162 100644 --- a/solver/internal/pipe/pipe_test.go +++ b/solver/internal/pipe/pipe_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/pkg/errors" "github.com/stretchr/testify/require" ) @@ -14,7 +15,7 @@ func TestPipe(t *testing.T) { f := func(ctx context.Context) (interface{}, error) { select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, context.Cause(ctx) case <-runCh: return "res0", nil } @@ -56,7 +57,7 @@ func TestPipeCancel(t *testing.T) { f := func(ctx context.Context) (interface{}, error) { select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, context.Cause(ctx) case <-runCh: return "res0", nil } @@ -88,5 +89,5 @@ func TestPipeCancel(t *testing.T) { require.Equal(t, st.Completed, true) require.Equal(t, st.Canceled, true) require.Error(t, st.Err) - require.Equal(t, st.Err, context.Canceled) + require.True(t, errors.Is(st.Err, context.Canceled)) } diff --git a/solver/jobs.go b/solver/jobs.go index 6f908b56f891e..b0678cae2a905 100644 --- a/solver/jobs.go +++ b/solver/jobs.go @@ -255,7 +255,7 @@ type Job struct { startedTime time.Time completedTime time.Time - progressCloser func() + progressCloser func(error) SessionID string uniqueID string // unique ID is used for provenance. We use a different field that client can't control } @@ -589,7 +589,7 @@ func (j *Job) walkProvenance(ctx context.Context, e Edge, f func(ProvenanceProvi } func (j *Job) CloseProgress() { - j.progressCloser() + j.progressCloser(errors.WithStack(context.Canceled)) j.pw.Close() } @@ -790,7 +790,7 @@ func (s *sharedOp) CalcSlowCache(ctx context.Context, index Index, p PreprocessF if errdefs.IsCanceled(ctx, err) { complete = false releaseError(err) - err = errors.Wrap(ctx.Err(), err.Error()) + err = errors.Wrap(context.Cause(ctx), err.Error()) } default: } @@ -856,7 +856,7 @@ func (s *sharedOp) CacheMap(ctx context.Context, index int) (resp *cacheMapResp, if errdefs.IsCanceled(ctx, err) { complete = false releaseError(err) - err = errors.Wrap(ctx.Err(), err.Error()) + err = errors.Wrap(context.Cause(ctx), err.Error()) } default: } @@ -935,7 +935,7 @@ func (s *sharedOp) Exec(ctx context.Context, inputs []Result) (outputs []Result, if errdefs.IsCanceled(ctx, err) { complete = false releaseError(err) - err = errors.Wrap(ctx.Err(), err.Error()) + err = errors.Wrap(context.Cause(ctx), err.Error()) } default: } diff --git a/solver/llbsolver/history.go b/solver/llbsolver/history.go index d055a7bf54f08..e3a69f7dcd88d 100644 --- a/solver/llbsolver/history.go +++ b/solver/llbsolver/history.go @@ -835,7 +835,7 @@ func (h *HistoryQueue) Listen(ctx context.Context, req *controlapi.BuildHistoryR for { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case e := <-sub.ch: if req.Ref != "" && req.Ref != e.Record.Ref { continue diff --git a/solver/llbsolver/mounts/mount.go b/solver/llbsolver/mounts/mount.go index 67eac20d727bf..9d6011dcef89e 100644 --- a/solver/llbsolver/mounts/mount.go +++ b/solver/llbsolver/mounts/mount.go @@ -136,7 +136,7 @@ func (g *cacheRefGetter) getRefCacheDirNoCache(ctx context.Context, key string, select { case <-ctx.Done(): cacheRefsLocker.Lock(key) - return nil, ctx.Err() + return nil, context.Cause(ctx) case <-time.After(100 * time.Millisecond): cacheRefsLocker.Lock(key) } @@ -199,7 +199,7 @@ type sshMountInstance struct { } func (sm *sshMountInstance) Mount() ([]mount.Mount, func() error, error) { - ctx, cancel := context.WithCancel(context.TODO()) + ctx, cancel := context.WithCancelCause(context.TODO()) uid := int(sm.sm.mount.SSHOpt.Uid) gid := int(sm.sm.mount.SSHOpt.Gid) @@ -210,7 +210,7 @@ func (sm *sshMountInstance) Mount() ([]mount.Mount, func() error, error) { GID: gid, }) if err != nil { - cancel() + cancel(err) return nil, nil, err } uid = identity.UID @@ -224,7 +224,7 @@ func (sm *sshMountInstance) Mount() ([]mount.Mount, func() error, error) { Mode: int(sm.sm.mount.SSHOpt.Mode & 0777), }) if err != nil { - cancel() + cancel(err) return nil, nil, err } release := func() error { @@ -232,7 +232,7 @@ func (sm *sshMountInstance) Mount() ([]mount.Mount, func() error, error) { if cleanup != nil { err = cleanup() } - cancel() + cancel(err) return err } diff --git a/solver/llbsolver/solver.go b/solver/llbsolver/solver.go index 9295e08c63720..b912006738079 100644 --- a/solver/llbsolver/solver.go +++ b/solver/llbsolver/solver.go @@ -484,7 +484,7 @@ func (s *Solver) Solve(ctx context.Context, id string, sessionID string, req fro case <-fwd.Done(): res, err = fwd.Result() case <-ctx.Done(): - err = ctx.Err() + err = context.Cause(ctx) } if err != nil { return nil, err diff --git a/solver/llbsolver/vertex.go b/solver/llbsolver/vertex.go index bd3cb30db03f0..d42045195182a 100644 --- a/solver/llbsolver/vertex.go +++ b/solver/llbsolver/vertex.go @@ -206,8 +206,10 @@ func recomputeDigests(ctx context.Context, all map[digest.Digest]*pb.Op, visited var mutated bool for _, input := range op.Inputs { - if ctx.Err() != nil { - return "", ctx.Err() + select { + case <-ctx.Done(): + return "", context.Cause(ctx) + default: } iDgst, err := recomputeDigests(ctx, all, visited, input.Digest) diff --git a/solver/progress.go b/solver/progress.go index 3fb954f867c42..92e2c6cb009fe 100644 --- a/solver/progress.go +++ b/solver/progress.go @@ -91,7 +91,7 @@ func (j *Job) Status(ctx context.Context, ch chan *client.SolveStatus) error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case ch <- ss: } } diff --git a/solver/scheduler.go b/solver/scheduler.go index 8b2fd8bfa3a91..cee36672640d3 100644 --- a/solver/scheduler.go +++ b/solver/scheduler.go @@ -245,8 +245,8 @@ func (s *scheduler) build(ctx context.Context, edge Edge) (CachedResult, error) } s.mu.Unlock() - ctx, cancel := context.WithCancel(ctx) - defer cancel() + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(errors.WithStack(context.Canceled)) go func() { <-ctx.Done() diff --git a/solver/scheduler_test.go b/solver/scheduler_test.go index c4647f8178b0b..79497913076d6 100644 --- a/solver/scheduler_test.go +++ b/solver/scheduler_test.go @@ -488,13 +488,13 @@ func TestSingleCancelCache(t *testing.T) { } }() - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) g0 := Edge{ Vertex: vtx(vtxOpt{ name: "v0", cachePreFunc: func(ctx context.Context) error { - cancel() + cancel(errors.WithStack(context.Canceled)) <-ctx.Done() return nil // error should still come from context }, @@ -530,13 +530,13 @@ func TestSingleCancelExec(t *testing.T) { } }() - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) g1 := Edge{ Vertex: vtx(vtxOpt{ name: "v2", execPreFunc: func(ctx context.Context) error { - cancel() + cancel(errors.WithStack(context.Canceled)) <-ctx.Done() return nil // error should still come from context }, @@ -580,8 +580,8 @@ func TestSingleCancelParallel(t *testing.T) { } }() - ctx, cancel := context.WithCancel(ctx) - defer cancel() + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(errors.WithStack(context.Canceled)) g := Edge{ Vertex: vtx(vtxOpt{ @@ -590,7 +590,7 @@ func TestSingleCancelParallel(t *testing.T) { cachePreFunc: func(ctx context.Context) error { close(firstReady) time.Sleep(200 * time.Millisecond) - cancel() + cancel(errors.WithStack(context.Canceled)) <-firstErrored return nil }, @@ -3452,13 +3452,13 @@ func (v *vertex) cacheMap(ctx context.Context) error { } select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) default: } select { case <-time.After(v.opt.cacheDelay): case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) } return nil } @@ -3489,13 +3489,13 @@ func (v *vertex) exec(ctx context.Context, inputs []Result) error { } select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) default: } select { case <-time.After(v.opt.execDelay): case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) } return nil } diff --git a/source/git/source_test.go b/source/git/source_test.go index 6712340a25d5a..1e9bb71c31c42 100644 --- a/source/git/source_test.go +++ b/source/git/source_test.go @@ -33,6 +33,7 @@ import ( "github.com/moby/buildkit/util/progress" "github.com/moby/buildkit/util/progress/logs" "github.com/moby/buildkit/util/winlayers" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" bolt "go.etcd.io/bbolt" @@ -663,7 +664,7 @@ func logProgressStreams(ctx context.Context, t *testing.T) context.Context { pr, ctx, cancel := progress.NewContext(ctx) done := make(chan struct{}) t.Cleanup(func() { - cancel() + cancel(errors.WithStack(context.Canceled)) <-done }) go func() { diff --git a/util/appcontext/appcontext.go b/util/appcontext/appcontext.go index f9cf0ba8fb60b..3d6626535b9a4 100644 --- a/util/appcontext/appcontext.go +++ b/util/appcontext/appcontext.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/moby/buildkit/util/bklog" + "github.com/pkg/errors" ) var appContextCache context.Context @@ -27,16 +28,17 @@ func Context() context.Context { ctx = f(ctx) } - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) appContextCache = ctx go func() { for { <-signals - cancel() retries++ + err := errors.Errorf("got %d SIGTERM/SIGINTs, forcing shutdown", retries) + cancel(err) if retries >= exitLimit { - bklog.G(ctx).Errorf("got %d SIGTERM/SIGINTs, forcing shutdown", retries) + bklog.G(ctx).Errorf(err.Error()) os.Exit(1) } } diff --git a/util/flightcontrol/flightcontrol.go b/util/flightcontrol/flightcontrol.go index 82ed25205fe42..42cb23678f1bd 100644 --- a/util/flightcontrol/flightcontrol.go +++ b/util/flightcontrol/flightcontrol.go @@ -90,7 +90,7 @@ type call[T any] struct { fn func(ctx context.Context) (T, error) once sync.Once - closeProgressWriter func() + closeProgressWriter func(error) progressState *progressState progressCtx context.Context } @@ -115,9 +115,9 @@ func newCall[T any](fn func(ctx context.Context) (T, error)) *call[T] { } func (c *call[T]) run() { - defer c.closeProgressWriter() - ctx, cancel := context.WithCancel(c.ctx) - defer cancel() + defer c.closeProgressWriter(errors.WithStack(context.Canceled)) + ctx, cancel := context.WithCancelCause(c.ctx) + defer cancel(errors.WithStack(context.Canceled)) v, err := c.fn(ctx) c.mu.Lock() c.result = v @@ -155,8 +155,8 @@ func (c *call[T]) wait(ctx context.Context) (v T, err error) { c.progressState.add(pw) } - ctx, cancel := context.WithCancel(ctx) - defer cancel() + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(errors.WithStack(context.Canceled)) c.ctxs = append(c.ctxs, ctx) @@ -175,7 +175,7 @@ func (c *call[T]) wait(ctx context.Context) (v T, err error) { if ok { c.progressState.close(pw) } - return empty, ctx.Err() + return empty, context.Cause(ctx) case <-c.ready: return c.result, c.err // shared not implemented yet } @@ -262,7 +262,9 @@ func (sc *sharedContext[T]) checkDone() bool { for _, ctx := range sc.ctxs { select { case <-ctx.Done(): - err = ctx.Err() + // Cause can't be used here because this error is returned for Err() in custom context + // implementation and unfortunately stdlib does not allow defining Cause() for custom contexts + err = ctx.Err() //nolint: forbidigo default: sc.mu.Unlock() return false diff --git a/util/flightcontrol/flightcontrol_test.go b/util/flightcontrol/flightcontrol_test.go index 3c8aebdfd8aab..2c7063808b0ce 100644 --- a/util/flightcontrol/flightcontrol_test.go +++ b/util/flightcontrol/flightcontrol_test.go @@ -50,7 +50,7 @@ func TestCancelOne(t *testing.T) { var r1, r2 string var counter int64 f := testFunc(100*time.Millisecond, "bar", &counter) - ctx2, cancel := context.WithCancel(ctx) + ctx2, cancel := context.WithCancelCause(ctx) eg.Go(func() error { ret1, err := g.Do(ctx2, "foo", f) assert.Error(t, err) @@ -71,9 +71,9 @@ func TestCancelOne(t *testing.T) { eg.Go(func() error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-time.After(30 * time.Millisecond): - cancel() + cancel(errors.WithStack(context.Canceled)) return nil } }) @@ -88,7 +88,7 @@ func TestCancelRace(t *testing.T) { // t.Parallel() // disabled for better timing consistency. works with parallel as well g := &Group[struct{}]{} - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancelCause(context.Background()) kick := make(chan struct{}) wait := make(chan struct{}) @@ -118,7 +118,7 @@ func TestCancelRace(t *testing.T) { time.Sleep(50 * time.Millisecond) select { case <-done: - return struct{}{}, ctx.Err() + return struct{}{}, context.Cause(ctx) case <-time.After(200 * time.Millisecond): } return struct{}{}, nil @@ -127,7 +127,7 @@ func TestCancelRace(t *testing.T) { go func() { defer close(wait) <-kick - cancel() + cancel(errors.WithStack(context.Canceled)) time.Sleep(5 * time.Millisecond) _, err := g.Do(context.Background(), "foo", f) require.NoError(t, err) @@ -146,8 +146,8 @@ func TestCancelBoth(t *testing.T) { var r1, r2 string var counter int64 f := testFunc(200*time.Millisecond, "bar", &counter) - ctx2, cancel2 := context.WithCancel(ctx) - ctx3, cancel3 := context.WithCancel(ctx) + ctx2, cancel2 := context.WithCancelCause(ctx) + ctx3, cancel3 := context.WithCancelCause(ctx) eg.Go(func() error { ret1, err := g.Do(ctx2, "foo", f) assert.Error(t, err) @@ -169,18 +169,18 @@ func TestCancelBoth(t *testing.T) { eg.Go(func() error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-time.After(20 * time.Millisecond): - cancel2() + cancel2(errors.WithStack(context.Canceled)) return nil } }) eg.Go(func() error { select { case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) case <-time.After(50 * time.Millisecond): - cancel3() + cancel3(errors.WithStack(context.Canceled)) return nil } }) @@ -228,7 +228,7 @@ func testFunc(wait time.Duration, ret string, counter *int64) func(ctx context.C atomic.AddInt64(counter, 1) select { case <-ctx.Done(): - return "", ctx.Err() + return "", context.Cause(ctx) case <-time.After(wait): return ret, nil } diff --git a/util/overlay/overlay_linux.go b/util/overlay/overlay_linux.go index 62179f9ce825b..8c018dfcef1f8 100644 --- a/util/overlay/overlay_linux.go +++ b/util/overlay/overlay_linux.go @@ -161,8 +161,10 @@ func Changes(ctx context.Context, changeFn fs.ChangeFunc, upperdir, upperdirView if err != nil { return err } - if ctx.Err() != nil { - return ctx.Err() + select { + case <-ctx.Done(): + return context.Cause(ctx) + default: } // Rebase path diff --git a/util/progress/multireader.go b/util/progress/multireader.go index b0d92dde8f252..d6d3fb7c79a9a 100644 --- a/util/progress/multireader.go +++ b/util/progress/multireader.go @@ -11,14 +11,15 @@ type MultiReader struct { main Reader initialized bool done chan struct{} - writers map[*progressWriter]func() + doneCause error + writers map[*progressWriter]func(error) sent []*Progress } func NewMultiReader(pr Reader) *MultiReader { mr := &MultiReader{ main: pr, - writers: make(map[*progressWriter]func()), + writers: make(map[*progressWriter]func(error)), done: make(chan struct{}), } return mr @@ -46,9 +47,9 @@ func (mr *MultiReader) Reader(ctx context.Context) Reader { go func() { if isBehind { - close := func() { + close := func(err error) { w.Close() - closeWriter() + closeWriter(err) } i := 0 for { @@ -58,11 +59,11 @@ func (mr *MultiReader) Reader(ctx context.Context) Reader { if count == 0 { select { case <-ctx.Done(): - close() + close(context.Cause(ctx)) mr.mu.Unlock() return case <-mr.done: - close() + close(mr.doneCause) mr.mu.Unlock() return default: @@ -77,7 +78,7 @@ func (mr *MultiReader) Reader(ctx context.Context) Reader { if i%100 == 0 { select { case <-ctx.Done(): - close() + close(context.Cause(ctx)) return default: } @@ -110,10 +111,12 @@ func (mr *MultiReader) handle() error { if err != nil { if err == io.EOF { mr.mu.Lock() + cancelErr := context.Canceled for w, c := range mr.writers { w.Close() - c() + c(cancelErr) } + mr.doneCause = cancelErr close(mr.done) mr.mu.Unlock() return nil diff --git a/util/progress/progress.go b/util/progress/progress.go index fbbb22de071ee..fb193113a766c 100644 --- a/util/progress/progress.go +++ b/util/progress/progress.go @@ -56,7 +56,7 @@ type WriterOption func(Writer) // NewContext returns a new context and a progress reader that captures all // progress items writtern to this context. Last returned parameter is a closer // function to signal that no new writes will happen to this context. -func NewContext(ctx context.Context) (Reader, context.Context, func()) { +func NewContext(ctx context.Context) (Reader, context.Context, func(error)) { pr, pw, cancel := pipe() ctx = WithProgress(ctx, pw) return pr, ctx, cancel @@ -141,7 +141,7 @@ func (pr *progressReader) Read(ctx context.Context) ([]*Progress, error) { select { case <-ctx.Done(): pr.mu.Unlock() - return nil, ctx.Err() + return nil, context.Cause(ctx) default: } dmap := pr.dirty @@ -185,8 +185,8 @@ func (pr *progressReader) append(pw *progressWriter) { } } -func pipe() (*progressReader, *progressWriter, func()) { - ctx, cancel := context.WithCancel(context.Background()) +func pipe() (*progressReader, *progressWriter, func(error)) { + ctx, cancel := context.WithCancelCause(context.Background()) pr := &progressReader{ ctx: ctx, writers: make(map[*progressWriter]struct{}), diff --git a/util/progress/progress_test.go b/util/progress/progress_test.go index 3eae80d0c90bf..a73115082de73 100644 --- a/util/progress/progress_test.go +++ b/util/progress/progress_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "golang.org/x/sync/errgroup" ) @@ -31,7 +32,7 @@ func TestProgress(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 15, s) - cancelProgress() + cancelProgress(errors.WithStack(context.Canceled)) err = eg.Wait() assert.NoError(t, err) @@ -56,7 +57,7 @@ func TestProgressNested(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 6, s) - cancelProgress() + cancelProgress(errors.WithStack(context.Canceled)) err = eg.Wait() assert.NoError(t, err) @@ -74,7 +75,7 @@ func calc(ctx context.Context, total int, name string) (int, error) { for i := 1; i <= total; i++ { select { case <-ctx.Done(): - return 0, ctx.Err() + return 0, context.Cause(ctx) case <-time.After(10 * time.Millisecond): } if i == total { diff --git a/util/progress/progressui/display.go b/util/progress/progressui/display.go index 01548229036aa..722fb77c64143 100644 --- a/util/progress/progressui/display.go +++ b/util/progress/progressui/display.go @@ -99,7 +99,7 @@ func (d Display) UpdateFrom(ctx context.Context, ch chan *client.SolveStatus) ([ for { select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, context.Cause(ctx) case <-ticker.C: d.disp.refresh() case ss, ok := <-ch: diff --git a/util/pull/pullprogress/progress.go b/util/pull/pullprogress/progress.go index 5ae047dbf549b..479a65016037e 100644 --- a/util/pull/pullprogress/progress.go +++ b/util/pull/pullprogress/progress.go @@ -31,7 +31,7 @@ func (p *ProviderWithProgress) ReaderAt(ctx context.Context, desc ocispecs.Descr return nil, err } - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) doneCh := make(chan struct{}) go trackProgress(ctx, desc, p.Manager, doneCh) return readerAtWithCancel{ReaderAt: ra, cancel: cancel, doneCh: doneCh, logger: bklog.G(ctx)}, nil @@ -39,13 +39,13 @@ func (p *ProviderWithProgress) ReaderAt(ctx context.Context, desc ocispecs.Descr type readerAtWithCancel struct { content.ReaderAt - cancel func() + cancel func(error) doneCh <-chan struct{} logger *logrus.Entry } func (ra readerAtWithCancel) Close() error { - ra.cancel() + ra.cancel(errors.WithStack(context.Canceled)) select { case <-ra.doneCh: case <-time.After(time.Second): @@ -65,7 +65,7 @@ func (f *FetcherWithProgress) Fetch(ctx context.Context, desc ocispecs.Descripto return nil, err } - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) doneCh := make(chan struct{}) go trackProgress(ctx, desc, f.Manager, doneCh) return readerWithCancel{ReadCloser: rc, cancel: cancel, doneCh: doneCh, logger: bklog.G(ctx)}, nil @@ -73,13 +73,13 @@ func (f *FetcherWithProgress) Fetch(ctx context.Context, desc ocispecs.Descripto type readerWithCancel struct { io.ReadCloser - cancel func() + cancel func(error) doneCh <-chan struct{} logger *logrus.Entry } func (r readerWithCancel) Close() error { - r.cancel() + r.cancel(errors.WithStack(context.Canceled)) select { case <-r.doneCh: case <-time.After(time.Second): diff --git a/util/staticfs/merge.go b/util/staticfs/merge.go index 0ff03f504861f..c5a582a5ffe5a 100644 --- a/util/staticfs/merge.go +++ b/util/staticfs/merge.go @@ -49,7 +49,7 @@ func (mfs *MergeFS) Walk(ctx context.Context, target string, fn fs.WalkDirFunc) case ch1 <- &record{path: path, entry: entry, err: err}: case <-ctx.Done(): } - return ctx.Err() + return context.Cause(ctx) }) }) eg.Go(func() error { @@ -59,7 +59,7 @@ func (mfs *MergeFS) Walk(ctx context.Context, target string, fn fs.WalkDirFunc) case ch2 <- &record{path: path, entry: entry, err: err}: case <-ctx.Done(): } - return ctx.Err() + return context.Cause(ctx) }) }) diff --git a/util/tracing/otlptracegrpc/client.go b/util/tracing/otlptracegrpc/client.go index e8d13301f3d53..3c05f43940473 100644 --- a/util/tracing/otlptracegrpc/client.go +++ b/util/tracing/otlptracegrpc/client.go @@ -70,7 +70,7 @@ func (c *client) UploadTraces(ctx context.Context, protoSpans []*tracepb.Resourc } ctx, cancel := c.connection.ContextWithStop(ctx) - defer cancel() + defer cancel(errors.WithStack(context.Canceled)) ctx, tCancel := context.WithTimeout(ctx, 30*time.Second) defer tCancel() diff --git a/util/tracing/otlptracegrpc/connection.go b/util/tracing/otlptracegrpc/connection.go index dbb0fcd39f476..6b52d3594e93a 100644 --- a/util/tracing/otlptracegrpc/connection.go +++ b/util/tracing/otlptracegrpc/connection.go @@ -22,6 +22,7 @@ import ( "time" "unsafe" + "github.com/pkg/errors" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) @@ -185,7 +186,7 @@ func (c *Connection) Shutdown(ctx context.Context) error { select { case <-c.backgroundConnectionDoneCh: case <-ctx.Done(): - return ctx.Err() + return context.Cause(ctx) } c.mu.Lock() @@ -200,17 +201,17 @@ func (c *Connection) Shutdown(ctx context.Context) error { return nil } -func (c *Connection) ContextWithStop(ctx context.Context) (context.Context, context.CancelFunc) { +func (c *Connection) ContextWithStop(ctx context.Context) (context.Context, context.CancelCauseFunc) { // Unify the parent context Done signal with the Connection's // stop channel. - ctx, cancel := context.WithCancel(ctx) - go func(ctx context.Context, cancel context.CancelFunc) { + ctx, cancel := context.WithCancelCause(ctx) + go func(ctx context.Context, cancel context.CancelCauseFunc) { select { case <-ctx.Done(): // Nothing to do, either cancelled or deadline // happened. case <-c.stopCh: - cancel() + cancel(errors.WithStack(context.Canceled)) } }(ctx, cancel) return ctx, cancel diff --git a/worker/tests/common.go b/worker/tests/common.go index 458f41a6d73cf..0a6d8be4c84f2 100644 --- a/worker/tests/common.go +++ b/worker/tests/common.go @@ -38,7 +38,7 @@ func NewCtx(s string) context.Context { func TestWorkerExec(t *testing.T, w *base.Worker) { ctx := NewCtx("buildkit-test") - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) sm, err := session.NewManager() require.NoError(t, err) @@ -151,7 +151,7 @@ func TestWorkerExec(t *testing.T, w *base.Worker) { require.Empty(t, stderr.String()) // stop pid1 - cancel() + cancel(errors.WithStack(context.Canceled)) err = eg.Wait() // we expect pid1 to get canceled after we test the exec @@ -248,8 +248,8 @@ func TestWorkerCancel(t *testing.T, w *base.Worker) { started := make(chan struct{}) - pid1Ctx, pid1Cancel := context.WithCancel(ctx) - defer pid1Cancel() + pid1Ctx, pid1Cancel := context.WithCancelCause(ctx) + defer pid1Cancel(errors.WithStack(context.Canceled)) var ( pid1Err, pid2Err error @@ -273,8 +273,8 @@ func TestWorkerCancel(t *testing.T, w *base.Worker) { t.Error("Unexpected timeout waiting for pid1 to start") } - pid2Ctx, pid2Cancel := context.WithCancel(ctx) - defer pid2Cancel() + pid2Ctx, pid2Cancel := context.WithCancelCause(ctx) + defer pid2Cancel(errors.WithStack(context.Canceled)) started = make(chan struct{}) @@ -299,11 +299,11 @@ func TestWorkerCancel(t *testing.T, w *base.Worker) { t.Error("Unexpected timeout waiting for pid2 to start") } - pid2Cancel() + pid2Cancel(errors.WithStack(context.Canceled)) <-pid2Done require.Contains(t, pid2Err.Error(), "exit code: 137", "pid2 exits with sigkill") - pid1Cancel() + pid1Cancel(errors.WithStack(context.Canceled)) <-pid1Done require.Contains(t, pid1Err.Error(), "exit code: 137", "pid1 exits with sigkill") }