From fef9154ea4c587679539c30b7c255fde084e965f Mon Sep 17 00:00:00 2001 From: David Sauer Date: Thu, 17 Feb 2022 17:22:02 +0100 Subject: [PATCH] stop 'kubectl port-forward' when the parent process (d8s) exits --- pkg/run.go | 53 +++++++++++++++++++++++++++++++++++------------------ pkg/up.go | 7 +++---- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/pkg/run.go b/pkg/run.go index 7b99853f..5920c2df 100644 --- a/pkg/run.go +++ b/pkg/run.go @@ -2,6 +2,7 @@ package d8s import ( "context" + "errors" "fmt" "log" "net" @@ -32,12 +33,11 @@ func Run(ctx context.Context, allowContext string, command []string) error { return fmt.Errorf("select free local port: %v", err) } - go portForwardForever(ctx, localPort, dindPort) - - err = awaitPortOpen(ctx, localPort) + cancel, err := startPortForward(ctx, localPort, dindPort) if err != nil { - return fmt.Errorf("wait for port forward to start: %v", err) + return fmt.Errorf("starting port-forward: %v", err) } + defer cancel() // execute command err = executeCommand(ctx, command, fmt.Sprintf("tcp://127.0.0.1:%d", localPort)) @@ -85,23 +85,40 @@ func freePort() (int, error) { return l.Addr().(*net.TCPAddr).Port, nil } -func portForwardForever(ctx context.Context, localPort, dindPort int) { - err := portForward(ctx, localPort, dindPort) - if err != nil { - log.Printf("port forward failed: %v", err) - } - - for { - select { - case <-ctx.Done(): - return - case <-time.After(10 * time.Millisecond): - err := portForward(ctx, localPort, dindPort) - if err != nil { - log.Printf("port forward failed: %v", err) +func startPortForward(ctx context.Context, localPort, dindPort int) (func(), error) { + ctx, cancel := context.WithCancel(ctx) + done := make(chan interface{}) + + go func() { + defer func() { + done <- struct{}{} + }() + + for { + select { + case <-ctx.Done(): + return + case <-time.After(10 * time.Millisecond): + err := portForward(ctx, localPort, dindPort) + isDone := len(ctx.Done()) > 0 || errors.Is(ctx.Err(), context.Canceled) + + if err != nil && !isDone { + log.Printf("port forward failed: %v", err) + } } } + }() + + err := awaitPortOpen(ctx, localPort) + if err != nil { + cancel() + return nil, fmt.Errorf("wait for port forward to start: %v", err) } + + return func() { + cancel() + <-done + }, nil } func portForward(ctx context.Context, localPort, dinnerPort int) error { diff --git a/pkg/up.go b/pkg/up.go index c28a398a..c82aa5a5 100644 --- a/pkg/up.go +++ b/pkg/up.go @@ -39,12 +39,11 @@ func Up(ctx context.Context, allowContext string, command []string) error { return fmt.Errorf("select free local port: %v", err) } - go portForwardForever(ctx, localPort, dindPort) - - err = awaitPortOpen(ctx, localPort) + cancel, err := startPortForward(ctx, localPort, dindPort) if err != nil { - return fmt.Errorf("wait for port forward to start: %v", err) + return fmt.Errorf("starting port-forward: %v", err) } + defer cancel() // execute command err = executeCommand(ctx, command, fmt.Sprintf("tcp://127.0.0.1:%d", localPort))