replace context.WithCancel with WithCancelCause #4457

Merged
merged 2 commits into from
Dec 12, 2023
4 changes: 4 additions & 0 deletions .golangci.yml
@@ -41,6 +41,10 @@ linters-settings:
forbid:
- '^fmt\.Errorf(# use errors\.Errorf instead)?$'
- '^logrus\.(Trace|Debug|Info|Warn|Warning|Error|Fatal)(f|ln)?(# use bklog\.G or bklog\.L instead of logrus directly)?$'
+ - '^context\.WithCancel(# use context\.WithCancelCause instead)?$'
Collaborator

This is great! Should we require context.WithTimeoutCause and context.WithDeadlineCause in this PR also? (I don't think we use WithDeadline anywhere, but should probably forbid it for the future)

Member Author

yes, marked in TODO comment

Member Author

@tonistiigi Dec 5, 2023

Actually, I'm not sure anything can be done about this. WithTimeoutCause/WithDeadlineCause have a completely different signature: the cause is passed in as a parameter instead of the function returning a func(error). That isn't really useful, as it basically just attaches the cause to the context up front (much like context.WithValue(ctx, "cause", cause)) and doesn't record where the cancellation actually happened.

Note that if the context actually reaches its timeout, this is somewhat expected, since the cancellation then happens somewhere inside the stdlib. But if the CancelFunc gets called, this stdlib function gives no way to detect the calling location.

Maybe something can still be done for the defer() case.

Member Author

I guess maybe something like this could work?

cause := &timeoutCause{}
ctx, cancel := context.WithTimeoutCause(ctx, ..., cause.Init())
cancelErr := cause.WithCancel(cancel)

// at the place where the context is canceled manually:
cancelErr(errors.WithStack(context.DeadlineExceeded))

type timeoutCause struct {
  error
}

// Init records the creation stack and returns t itself as the cause,
// so a later WithCancel call can swap the error that the cause reports.
func (t *timeoutCause) Init() error {
  // todo: atomic
  // store the context creation stack. this will be shown when timeout is reached and runtime cancels
  t.error = errors.WithStack(context.DeadlineExceeded)
  return t
}

func (t *timeoutCause) WithCancel(c func()) func(error) {
  return func(e error) {
    // todo: atomic
    // store the context cancellation stack. this will be shown when cancel is called manually
    t.error = errors.WithStack(e)
    c()
  }
}

Member Author

In the go issue golang/go#56661 there is also another workaround mentioned.

ctx, cancel := context.WithCancelCause(context.Background())
ctx, _ = context.WithTimeoutCause(ctx, 1*time.Second, tooSlow)

This is simpler, but I'm afraid it will trigger a linter error about a leaked context, which would have to be disabled every time. I really don't understand why WithTimeoutCause doesn't just return a CancelCauseFunc.

Collaborator

@coryb Dec 5, 2023

Hmm for WithTimeoutCause I use something like:

context.WithTimeoutCause(ctx, time.Minute*5, errors.WithStack(context.DeadlineExceeded))

That way we get the stack showing which specific timeout was triggered.

Member Author

This will give you the stack trace to the place that created the context. But consider, for example:

{
ctx, cancel := context.WithTimeoutCause(ctx, time.Minute*5, errors.WithStack(context.DeadlineExceeded))
defer cancel()

go func() {
  // leaky goroutine
  err := foo(ctx)
  // cancellation error because context was discarded but no stacktrace to where it happened
}()

return nil
}

Collaborator

I see, yeah, maybe add a helper function to wrap timeouts, and add that to the forbidigo rule?

func withTimeout(ctx context.Context, d time.Duration) (context.Context, func(error)) {
    ctx, cancel := context.WithCancelCause(ctx)
    ctx, _ = context.WithTimeoutCause(ctx, d, errors.WithStack(context.DeadlineExceeded))
    return ctx, cancel
}

Member Author

It turns out the linter only has this leak detection for context.WithTimeout, and nobody has added a rule for context.WithTimeoutCause yet. So this is a problem for someone in the future trying to update the linter 😉

+ - '^context\.WithTimeout(# use context\.WithTimeoutCause instead)?$'
+ - '^context\.WithDeadline(# use context\.WithDeadlineCause instead)?$'
+ - '^ctx\.Err(# use context\.Cause instead)?$'
importas:
alias:
- pkg: "github.com/opencontainers/image-spec/specs-go/v1"
2 changes: 1 addition & 1 deletion cache/manager.go
@@ -1258,7 +1258,7 @@ func (cm *cacheManager) prune(ctx context.Context, ch chan client.UsageInfo, opt

select {
case <-ctx.Done():
- return ctx.Err()
+ return context.Cause(ctx)
default:
return cm.prune(ctx, ch, opt)
}
10 changes: 6 additions & 4 deletions cache/remotecache/azblob/exporter.go
@@ -150,8 +150,9 @@ func (ce *exporter) uploadManifest(ctx context.Context, manifestKey string, read
return errors.Wrap(err, "error creating container client")
}

- ctx, cnclFn := context.WithTimeout(ctx, time.Minute*5)
- defer cnclFn()
+ ctx, cnclFn := context.WithCancelCause(ctx)
+ ctx, _ = context.WithTimeoutCause(ctx, time.Minute*5, errors.WithStack(context.DeadlineExceeded))
+ defer cnclFn(errors.WithStack(context.Canceled))

_, err = blobClient.Upload(ctx, reader, &azblob.BlockBlobUploadOptions{})
if err != nil {
@@ -170,8 +171,9 @@ func (ce *exporter) uploadBlobIfNotExists(ctx context.Context, blobKey string, r
return errors.Wrap(err, "error creating container client")
}

- uploadCtx, cnclFn := context.WithTimeout(ctx, time.Minute*5)
- defer cnclFn()
+ uploadCtx, cnclFn := context.WithCancelCause(ctx)
+ uploadCtx, _ = context.WithTimeoutCause(uploadCtx, time.Minute*5, errors.WithStack(context.DeadlineExceeded))
+ defer cnclFn(errors.WithStack(context.Canceled))

// Only upload if the blob doesn't exist
eTagAny := azblob.ETagAny
15 changes: 9 additions & 6 deletions cache/remotecache/azblob/utils.go
@@ -133,8 +133,9 @@ func createContainerClient(ctx context.Context, config *Config) (*azblob.Contain
}
}

- ctx, cnclFn := context.WithTimeout(ctx, time.Second*60)
- defer cnclFn()
+ ctx, cnclFn := context.WithCancelCause(ctx)
+ ctx, _ = context.WithTimeoutCause(ctx, time.Second*60, errors.WithStack(context.DeadlineExceeded))
+ defer cnclFn(errors.WithStack(context.Canceled))

containerClient, err := serviceClient.NewContainerClient(config.Container)
if err != nil {
@@ -148,8 +149,9 @@ func createContainerClient(ctx context.Context, config *Config) (*azblob.Contain

var se *azblob.StorageError
if errors.As(err, &se) && se.ErrorCode == azblob.StorageErrorCodeContainerNotFound {
- ctx, cnclFn := context.WithTimeout(ctx, time.Minute*5)
- defer cnclFn()
+ ctx, cnclFn := context.WithCancelCause(ctx)
+ ctx, _ = context.WithTimeoutCause(ctx, time.Minute*5, errors.WithStack(context.DeadlineExceeded))
+ defer cnclFn(errors.WithStack(context.Canceled))
_, err := containerClient.Create(ctx, &azblob.ContainerCreateOptions{})
if err != nil {
return nil, errors.Wrapf(err, "failed to create cache container %s", config.Container)
@@ -177,8 +179,9 @@ func blobExists(ctx context.Context, containerClient *azblob.ContainerClient, bl
return false, errors.Wrap(err, "error creating blob client")
}

- ctx, cnclFn := context.WithTimeout(ctx, time.Second*60)
- defer cnclFn()
+ ctx, cnclFn := context.WithCancelCause(ctx)
+ ctx, _ = context.WithTimeoutCause(ctx, time.Second*60, errors.WithStack(context.DeadlineExceeded))
+ defer cnclFn(errors.WithStack(context.Canceled))
_, err = blobClient.GetProperties(ctx, &azblob.BlobGetPropertiesOptions{})
if err == nil {
return true, nil
5 changes: 3 additions & 2 deletions cache/remotecache/local/local.go
@@ -105,8 +105,9 @@ func getContentStore(ctx context.Context, sm *session.Manager, g session.Group,
if sessionID == "" {
return nil, errors.New("local cache exporter/importer requires session")
}
- timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
+ timeoutCtx, cancel := context.WithCancelCause(context.Background())
+ timeoutCtx, _ = context.WithTimeoutCause(timeoutCtx, 5*time.Second, errors.WithStack(context.DeadlineExceeded))
+ defer cancel(errors.WithStack(context.Canceled))

caller, err := sm.Get(timeoutCtx, sessionID, false)
if err != nil {
16 changes: 8 additions & 8 deletions client/build_test.go
@@ -933,7 +933,7 @@ func testClientGatewayContainerPID1Tty(t *testing.T, sb integration.Sandbox) {
output := bytes.NewBuffer(nil)

b := func(ctx context.Context, c client.Client) (*client.Result, error) {
- ctx, timeout := context.WithTimeout(ctx, 10*time.Second)
+ ctx, timeout := context.WithTimeoutCause(ctx, 10*time.Second, nil)
defer timeout()

st := llb.Image("busybox:latest")
@@ -1015,7 +1015,7 @@ func testClientGatewayContainerCancelPID1Tty(t *testing.T, sb integration.Sandbo
output := bytes.NewBuffer(nil)

b := func(ctx context.Context, c client.Client) (*client.Result, error) {
- ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ ctx, cancel := context.WithTimeoutCause(ctx, 10*time.Second, nil)
defer cancel()

st := llb.Image("busybox:latest")
@@ -1141,7 +1141,7 @@ func testClientGatewayContainerExecTty(t *testing.T, sb integration.Sandbox) {
inputR, inputW := io.Pipe()
output := bytes.NewBuffer(nil)
b := func(ctx context.Context, c client.Client) (*client.Result, error) {
- ctx, timeout := context.WithTimeout(ctx, 10*time.Second)
+ ctx, timeout := context.WithTimeoutCause(ctx, 10*time.Second, nil)
defer timeout()
st := llb.Image("busybox:latest")

@@ -1233,7 +1233,7 @@ func testClientGatewayContainerCancelExecTty(t *testing.T, sb integration.Sandbo
inputR, inputW := io.Pipe()
output := bytes.NewBuffer(nil)
b := func(ctx context.Context, c client.Client) (*client.Result, error) {
- ctx, timeout := context.WithTimeout(ctx, 10*time.Second)
+ ctx, timeout := context.WithTimeoutCause(ctx, 10*time.Second, nil)
defer timeout()
st := llb.Image("busybox:latest")

@@ -1266,8 +1266,8 @@ func testClientGatewayContainerCancelExecTty(t *testing.T, sb integration.Sandbo
defer pid1.Wait()
defer ctr.Release(ctx)

- execCtx, cancel := context.WithCancel(ctx)
- defer cancel()
+ execCtx, cancel := context.WithCancelCause(ctx)
+ defer cancel(errors.WithStack(context.Canceled))

prompt := newTestPrompt(execCtx, t, inputW, output)
pid2, err := ctr.Start(execCtx, client.StartRequest{
@@ -1281,7 +1281,7 @@ func testClientGatewayContainerCancelExecTty(t *testing.T, sb integration.Sandbo
require.NoError(t, err)

prompt.SendExpect("echo hi", "hi")
- cancel()
+ cancel(errors.WithStack(context.Canceled))

err = pid2.Wait()
require.ErrorIs(t, err, context.Canceled)
@@ -2132,7 +2132,7 @@ func testClientGatewayContainerSignal(t *testing.T, sb integration.Sandbox) {
product := "buildkit_test"

b := func(ctx context.Context, c client.Client) (*client.Result, error) {
- ctx, timeout := context.WithTimeout(ctx, 10*time.Second)
+ ctx, timeout := context.WithTimeoutCause(ctx, 10*time.Second, nil)
defer timeout()

st := llb.Image("busybox:latest")
2 changes: 1 addition & 1 deletion client/client.go
@@ -205,7 +205,7 @@ func (c *Client) Wait(ctx context.Context) error {

select {
case <-ctx.Done():
- return ctx.Err()
+ return context.Cause(ctx)
case <-time.After(time.Second):
}
c.conn.ResetConnectBackoff()
6 changes: 3 additions & 3 deletions client/client_test.go
@@ -7407,8 +7407,8 @@ func testInvalidExporter(t *testing.T, sb integration.Sandbox) {

// moby/buildkit#492
func testParallelLocalBuilds(t *testing.T, sb integration.Sandbox) {
- ctx, cancel := context.WithCancel(sb.Context())
- defer cancel()
+ ctx, cancel := context.WithCancelCause(sb.Context())
+ defer cancel(errors.WithStack(context.Canceled))

c, err := New(ctx, sb.Address())
require.NoError(t, err)
@@ -9832,7 +9832,7 @@ func testLLBMountPerformance(t *testing.T, sb integration.Sandbox) {
def, err := st.Marshal(sb.Context())
require.NoError(t, err)

- timeoutCtx, cancel := context.WithTimeout(sb.Context(), time.Minute)
+ timeoutCtx, cancel := context.WithTimeoutCause(sb.Context(), time.Minute, nil)
defer cancel()
_, err = c.Solve(timeoutCtx, def, SolveOpt{}, nil)
require.NoError(t, err)
2 changes: 1 addition & 1 deletion client/llb/async.go
@@ -61,7 +61,7 @@ func (as *asyncState) Do(ctx context.Context, c *Constraints) error {
if err != nil {
select {
case <-ctx.Done():
- if errors.Is(err, ctx.Err()) {
+ if errors.Is(err, context.Cause(ctx)) {
return res, err
}
default:
12 changes: 6 additions & 6 deletions client/solve.go
@@ -106,8 +106,8 @@ func (c *Client) solve(ctx context.Context, def *llb.Definition, runGateway runG
}
eg, ctx := errgroup.WithContext(ctx)

- statusContext, cancelStatus := context.WithCancel(context.Background())
- defer cancelStatus()
+ statusContext, cancelStatus := context.WithCancelCause(context.Background())
+ defer cancelStatus(errors.WithStack(context.Canceled))

if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() {
statusContext = trace.ContextWithSpan(statusContext, span)
@@ -230,16 +230,16 @@ func (c *Client) solve(ctx context.Context, def *llb.Definition, runGateway runG
frontendAttrs[k] = v
}

- solveCtx, cancelSolve := context.WithCancel(ctx)
+ solveCtx, cancelSolve := context.WithCancelCause(ctx)
var res *SolveResponse
eg.Go(func() error {
ctx := solveCtx
- defer cancelSolve()
+ defer cancelSolve(errors.WithStack(context.Canceled))

defer func() { // make sure the Status ends cleanly on build errors
go func() {
<-time.After(3 * time.Second)
- cancelStatus()
+ cancelStatus(errors.WithStack(context.Canceled))
}()
if !opt.SessionPreInitialized {
bklog.G(ctx).Debugf("stopping session")
@@ -298,7 +298,7 @@ func (c *Client) solve(ctx context.Context, def *llb.Definition, runGateway runG
select {
case <-solveCtx.Done():
case <-time.After(5 * time.Second):
- cancelSolve()
+ cancelSolve(errors.WithStack(context.Canceled))
}

return err
5 changes: 3 additions & 2 deletions cmd/buildctl/common/common.go
@@ -91,9 +91,10 @@ func ResolveClient(c *cli.Context) (*client.Client, error) {

timeout := time.Duration(c.GlobalInt("timeout"))
if timeout > 0 {
- ctx2, cancel := context.WithTimeout(ctx, timeout*time.Second)
+ ctx2, cancel := context.WithCancelCause(ctx)
+ ctx2, _ = context.WithTimeoutCause(ctx2, timeout*time.Second, errors.WithStack(context.DeadlineExceeded))
ctx = ctx2
- defer cancel()
+ defer cancel(errors.WithStack(context.Canceled))
}

cl, err := client.New(ctx, c.GlobalString("addr"), opts...)
14 changes: 7 additions & 7 deletions cmd/buildkitd/main.go
@@ -223,8 +223,8 @@ func main() {
if os.Geteuid() > 0 {
return errors.New("rootless mode requires to be executed as the mapped root in a user namespace; you may use RootlessKit for setting up the namespace")
}
- ctx, cancel := context.WithCancel(appcontext.Context())
- defer cancel()
+ ctx, cancel := context.WithCancelCause(appcontext.Context())
+ defer cancel(errors.WithStack(context.Canceled))

cfg, err := config.LoadFile(c.GlobalString("config"))
if err != nil {
@@ -344,9 +344,9 @@ func main() {
select {
case serverErr := <-errCh:
err = serverErr
- cancel()
+ cancel(err)
case <-ctx.Done():
- err = ctx.Err()
+ err = context.Cause(ctx)
}

bklog.G(ctx).Infof("stopping server")
@@ -634,14 +634,14 @@ func unaryInterceptor(globalCtx context.Context, tp trace.TracerProvider) grpc.U
withTrace := otelgrpc.UnaryServerInterceptor(otelgrpc.WithTracerProvider(tp), otelgrpc.WithPropagators(propagators))

return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
- ctx, cancel := context.WithCancel(ctx)
- defer cancel()
+ ctx, cancel := context.WithCancelCause(ctx)
+ defer cancel(errors.WithStack(context.Canceled))

go func() {
select {
case <-ctx.Done():
case <-globalCtx.Done():
- cancel()
+ cancel(context.Cause(globalCtx))
}
}()

4 changes: 2 additions & 2 deletions control/control.go
@@ -505,10 +505,10 @@ func (c *Controller) Session(stream controlapi.Control_SessionServer) error {
conn, closeCh, opts := grpchijack.Hijack(stream)
defer conn.Close()

- ctx, cancel := context.WithCancel(stream.Context())
+ ctx, cancel := context.WithCancelCause(stream.Context())
go func() {
<-closeCh
- cancel()
+ cancel(errors.WithStack(context.Canceled))
}()

err := c.opt.SessionManager.HandleConn(ctx, conn, opts)
5 changes: 3 additions & 2 deletions control/gateway/gateway.go
@@ -59,8 +59,9 @@ func (gwf *GatewayForwarder) lookupForwarder(ctx context.Context) (gateway.LLBBr
return nil, errors.New("no buildid found in context")
}

- ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
- defer cancel()
+ ctx, cancel := context.WithCancelCause(ctx)
+ ctx, _ = context.WithTimeoutCause(ctx, 3*time.Second, errors.WithStack(context.DeadlineExceeded))
+ defer cancel(errors.WithStack(context.Canceled))

go func() {
<-ctx.Done()
17 changes: 9 additions & 8 deletions executor/containerdexecutor/executor.go
@@ -243,7 +243,7 @@ func (w *containerdExecutor) Exec(ctx context.Context, id string, process execut
}
select {
case <-ctx.Done():
- return ctx.Err()
+ return context.Cause(ctx)
case err, ok := <-details.done:
if !ok || err == nil {
return errors.Errorf("container %s has stopped", id)
@@ -336,8 +336,8 @@ func (w *containerdExecutor) runProcess(ctx context.Context, p containerd.Proces

// handle signals (and resize) in separate go loop so it does not
// potentially block the container cancel/exit status loop below.
- eventCtx, eventCancel := context.WithCancel(ctx)
- defer eventCancel()
+ eventCtx, eventCancel := context.WithCancelCause(ctx)
+ defer eventCancel(errors.WithStack(context.Canceled))
go func() {
for {
select {
@@ -371,21 +371,22 @@ func (w *containerdExecutor) runProcess(ctx context.Context, p containerd.Proces
}
}()

- var cancel func()
+ var cancel func(error)
var killCtxDone <-chan struct{}
ctxDone := ctx.Done()
for {
select {
case <-ctxDone:
ctxDone = nil
var killCtx context.Context
- killCtx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
+ killCtx, cancel = context.WithCancelCause(context.Background())
+ killCtx, _ = context.WithTimeoutCause(killCtx, 10*time.Second, errors.WithStack(context.DeadlineExceeded))
killCtxDone = killCtx.Done()
p.Kill(killCtx, syscall.SIGKILL)
io.Cancel()
case status := <-statusCh:
if cancel != nil {
- cancel()
+ cancel(errors.WithStack(context.Canceled))
}
trace.SpanFromContext(ctx).AddEvent(
"Container exited",
@@ -403,15 +404,15 @@ func (w *containerdExecutor) runProcess(ctx context.Context, p containerd.Proces
}
select {
case <-ctx.Done():
- exitErr.Err = errors.Wrap(ctx.Err(), exitErr.Error())
+ exitErr.Err = errors.Wrap(context.Cause(ctx), exitErr.Error())
default:
}
return exitErr
}
return nil
case <-killCtxDone:
if cancel != nil {
- cancel()
+ cancel(errors.WithStack(context.Canceled))
}
io.Cancel()
return errors.Errorf("failed to kill process on cancel")