Skip to content

Commit

Permalink
fix: prompt should exit on io.EOF
Browse files Browse the repository at this point in the history
Signed-off-by: Alano Terblanche <18033717+Benehiko@users.noreply.github.com>
  • Loading branch information
Benehiko committed Feb 27, 2024
1 parent e662a1e commit 135075d
Show file tree
Hide file tree
Showing 14 changed files with 127 additions and 41 deletions.
2 changes: 1 addition & 1 deletion cli/command/builder/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ func TestBuilderPromptTermination(t *testing.T) {
// set a fake reader so that our kill signal reaches the prompt before the prompt reads from stdin
cli.SetIn(test.NewFakeStreamIn(ctx, time.Second*2))
cmd := NewPruneCommand(cli)
test.TerminatePrompt(t, cmd, cli, nil)
test.TerminatePrompt(ctx, t, cmd, cli, nil)
}
2 changes: 1 addition & 1 deletion cli/command/container/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ func TestContainerPrunePromptTermination(t *testing.T) {
// set a fake reader so that our kill signal reaches the prompt before the prompt reads from stdin
cli.SetIn(test.NewFakeStreamIn(ctx, time.Second*2))
cmd := NewPruneCommand(cli)
test.TerminatePrompt(t, cmd, cli, nil)
test.TerminatePrompt(ctx, t, cmd, cli, nil)
}
2 changes: 1 addition & 1 deletion cli/command/image/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,5 +116,5 @@ func TestPrunePromptTermination(t *testing.T) {
// set a fake reader so that our kill signal reaches the prompt before the prompt reads from stdin
cli.SetIn(test.NewFakeStreamIn(ctx, time.Second*2))
cmd := NewPruneCommand(cli)
test.TerminatePrompt(t, cmd, cli, nil)
test.TerminatePrompt(ctx, t, cmd, cli, nil)
}
2 changes: 1 addition & 1 deletion cli/command/network/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ func TestNetworkPrunePromptTermination(t *testing.T) {
// set a fake reader so that our kill signal reaches the prompt before the prompt reads from stdin
cli.SetIn(test.NewFakeStreamIn(ctx, time.Second*2))
cmd := NewPruneCommand(cli)
test.TerminatePrompt(t, cmd, cli, nil)
test.TerminatePrompt(ctx, t, cmd, cli, nil)
}
2 changes: 1 addition & 1 deletion cli/command/network/remove_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,5 @@ func TestNetworkRemovePromptTermination(t *testing.T) {
cli.SetIn(test.NewFakeStreamIn(ctx, time.Second*2))
cmd := newRemoveCommand(cli)
cmd.SetArgs([]string{"existing-network"})
test.TerminatePrompt(t, cmd, cli, nil)
test.TerminatePrompt(ctx, t, cmd, cli, nil)
}
13 changes: 11 additions & 2 deletions cli/command/plugin/upgrade_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package plugin

import (
"context"
"io"
"testing"
"time"

"github.com/docker/cli/internal/test"
"github.com/docker/docker/api/types"
Expand All @@ -12,6 +14,12 @@ import (
)

func TestUpgradePromptTermination(t *testing.T) {
t.Cleanup(func() {
})

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

cli := test.NewFakeCli(&fakeClient{
pluginUpgradeFunc: func(name string, options types.PluginInstallOptions) (io.ReadCloser, error) {
return nil, errors.New("should not be called")
Expand All @@ -25,11 +33,12 @@ func TestUpgradePromptTermination(t *testing.T) {
}, []byte{}, nil
},
})
cli.SetIn(test.NewFakeStreamIn(ctx, time.Second*2))
cmd := newUpgradeCommand(cli)
// need to set a remote address that does not match the plugin
// reference sent by the `pluginInspectFunc`
cmd.SetArgs([]string{"plugin", "localhost:5000/foo/bar:v1.0.0"})
test.TerminatePrompt(t, cmd, cli, func(t *testing.T, err error) {
cmd.SetArgs([]string{"foo/bar", "localhost:5000/foo/bar:v1.0.0"})
test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) {
t.Helper()
assert.Error(t, err, "canceling upgrade request")
})
Expand Down
2 changes: 1 addition & 1 deletion cli/command/system/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,5 @@ func TestSystemPrunePromptTermination(t *testing.T) {
// set a fake reader so that our kill signal reaches the prompt before the prompt reads from stdin
cli.SetIn(test.NewFakeStreamIn(ctx, time.Second*2))
cmd := newPruneCommand(cli)
test.TerminatePrompt(t, cmd, cli, nil)
test.TerminatePrompt(ctx, t, cmd, cli, nil)
}
3 changes: 1 addition & 2 deletions cli/command/trust/revoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package trust
import (
"context"
"fmt"
"os"

"github.com/docker/cli/cli"
"github.com/docker/cli/cli/command"
Expand Down Expand Up @@ -44,7 +43,7 @@ func revokeTrust(ctx context.Context, dockerCLI command.Cli, remote string, opti
return fmt.Errorf("cannot use a digest reference for IMAGE:TAG")
}
if imgRefAndAuth.Tag() == "" && !options.forceYes {
deleteRemote := command.PromptForConfirmation(ctx, os.Stdin, dockerCLI.Out(), fmt.Sprintf("Please confirm you would like to delete all signature data for %s?", remote))
deleteRemote := command.PromptForConfirmation(ctx, dockerCLI.In(), dockerCLI.Out(), fmt.Sprintf("Please confirm you would like to delete all signature data for %s?", remote))
if !deleteRemote {
fmt.Fprintf(dockerCLI.Out(), "\nAborting action.\n")
return nil
Expand Down
2 changes: 1 addition & 1 deletion cli/command/trust/revoke_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func TestRevokeTrustPromptTermination(t *testing.T) {
cmd := newRevokeCommand(cli)
cmd.SetArgs([]string{"example/trust-demo"})

test.TerminatePrompt(t, cmd, cli, nil)
test.TerminatePrompt(ctx, t, cmd, cli, nil)
assert.Equal(t, cli.ErrBuffer().String(), "")
golden.Assert(t, cli.OutBuffer().String(), "trust-revoke-prompt-termination.golden")
}
14 changes: 9 additions & 5 deletions cli/command/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,20 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m
defer notifyCancel()

go func() {
reader := bufio.NewScanner(ins)
for {
defer notifyCancel()
scanner := bufio.NewScanner(ins)
for scanner.Scan() {
select {
case <-notifyCtx.Done():
result <- false
return
default:
if reader.Scan() {
result <- (strings.EqualFold(strings.ToLower(reader.Text()), "y"))
return
var r bool
if strings.EqualFold(strings.ToLower(scanner.Text()), "y") {
r = true
}
result <- r
return
}
}
}()
Expand Down
91 changes: 71 additions & 20 deletions cli/command/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,33 +78,84 @@ func TestValidateOutputPath(t *testing.T) {
}

func TestPromptForConfirmation(t *testing.T) {
buf := new(bytes.Buffer)
w := bufio.NewWriter(buf)

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

in := test.NewFakeStreamIn(ctx, time.Second*2)
t.Run("case=terminate the prompt with SIGINT", func(t *testing.T) {
wroteHook := make(chan struct{}, 1)
buf := new(bytes.Buffer)
bufioWriter := bufio.NewWriter(buf)
w := test.NewWriterWithHook(bufioWriter, func() {
wroteHook <- struct{}{}
})

t.Cleanup(func() {
in.Close()
in := test.NewFakeStreamIn(ctx, time.Second*2)
t.Cleanup(func() {
assert.NilError(t, in.Close())
})
result := make(chan bool, 1)
go func() {
defer close(result)
result <- command.PromptForConfirmation(ctx, in, w, "")
}()

// wait for the Prompt to write to the buffer
pollForPromptOutput(ctx, t, wroteHook)
assert.NilError(t, bufioWriter.Flush())
assert.Equal(t, strings.TrimSpace(buf.String()), "Are you sure you want to proceed? [y/N]")

syscall.Kill(syscall.Getpid(), syscall.SIGINT)
select {
case r := <-result:
assert.Check(t, !r)
case <-time.After(100 * time.Millisecond):
t.Fatal("PromptForConfirmation did not return after SIGINT")
}
})

result := make(chan bool, 1)
go func() {
defer close(result)
result <- command.PromptForConfirmation(ctx, in, w, "")
}()
t.Run("case=prompt should return on io.EOF", func(t *testing.T) {
buf := new(bytes.Buffer)
wroteHook := make(chan struct{}, 1)
bufioWriter := bufio.NewWriter(buf)
w := test.NewWriterWithHook(bufioWriter, func() {
wroteHook <- struct{}{}
})
in := test.NewFakeStreamIn(ctx, 0)

result := make(chan bool, 1)
go func() {
result <- command.PromptForConfirmation(ctx, in, w, "")
}()

// wait for the prompt to write to the buffer
pollForPromptOutput(ctx, t, wroteHook)

assert.NilError(t, bufioWriter.Flush())
assert.Check(t, strings.Contains(buf.String(), "Are you sure you want to proceed? [y/N]"))
assert.NilError(t, in.Close())

select {
case r := <-result:
assert.Check(t, !r)
case <-time.After(100 * time.Millisecond):
t.Fatal("PromptForConfirmation did not return after io.EOF")
}
})
}

func pollForPromptOutput(ctx context.Context, t *testing.T, wroteHook <-chan struct{}) {
t.Helper()

time.Sleep(100 * time.Millisecond)
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()

syscall.Kill(syscall.Getpid(), syscall.SIGINT)
select {
case r := <-result:
assert.Check(t, !r)
case <-time.After(1 * time.Millisecond):
t.Fatal("PromptForConfirmation did not return after SIGINT")
for {
select {
case <-ctx.Done():
t.Fatal("Buffered output was not written to before ctx was cancelled")
return
case <-wroteHook:
return
}
}
assert.NilError(t, w.Flush())
assert.Equal(t, strings.TrimSpace(buf.String()), "Are you sure you want to proceed? [y/N]")
}
2 changes: 1 addition & 1 deletion cli/command/volume/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func TestVolumePrunePromptTerminate(t *testing.T) {
// set a fake reader so that our kill signal reaches the prompt before the prompt reads from stdin
cli.SetIn(test.NewFakeStreamIn(ctx, time.Second*2))
cmd := NewPruneCommand(cli)
test.TerminatePrompt(t, cmd, cli, nil)
test.TerminatePrompt(ctx, t, cmd, cli, nil)

golden.Assert(t, cli.OutBuffer().String(), "volume-prune-terminate.golden")
}
9 changes: 5 additions & 4 deletions internal/test/cmd.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package test

import (
"context"
"syscall"
"testing"
"time"
Expand All @@ -9,16 +10,16 @@ import (
"gotest.tools/v3/assert"
)

func TerminatePrompt(t *testing.T, cmd *cobra.Command, cli *FakeCli, assertFunc func(*testing.T, error)) {
func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli *FakeCli, assertFunc func(*testing.T, error)) {
t.Helper()

errChan := make(chan error)
defer close(errChan)

go func() {
errChan <- cmd.Execute()
errChan <- cmd.ExecuteContext(ctx)
}()

// wait for the prompt to be displayed
time.Sleep(100 * time.Millisecond)

syscall.Kill(syscall.Getpid(), syscall.SIGINT)
Expand All @@ -31,7 +32,7 @@ func TerminatePrompt(t *testing.T, cmd *cobra.Command, cli *FakeCli, assertFunc
return
}
assert.NilError(t, err)
case <-time.After(100 * time.Millisecond):
case <-time.After(1000 * time.Millisecond):
t.Fatalf("command %s did not return after SIGINT", cmd.Name())
}

Expand Down
22 changes: 22 additions & 0 deletions internal/test/writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package test

import (
"io"
)

type WriterWithHook struct {
actualWriter io.Writer
hook func()
}

// Write implements io.Writer.
func (w *WriterWithHook) Write(p []byte) (n int, err error) {
defer w.hook()
return w.actualWriter.Write(p)
}

var _ io.Writer = (*WriterWithHook)(nil)

func NewWriterWithHook(actualWriter io.Writer, hook func()) *WriterWithHook {
return &WriterWithHook{actualWriter: actualWriter, hook: hook}
}

0 comments on commit 135075d

Please sign in to comment.