Skip to content

Commit

Permalink
feat: unify prompts (#1344)
Browse files Browse the repository at this point in the history
  • Loading branch information
pd93 committed Oct 7, 2023
1 parent 222cd8c commit dc77286
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 56 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -2,6 +2,9 @@

## Unreleased

- Enabled the `--yes` flag for the
[Remote Taskfiles experiment](https://taskfile.dev/experiments/remote-taskfiles)
(#1344 by @pd93).
- Add ability to set `watch: true` in a task to automatically run it in watch
mode (#231, #1361 by @andreynering).
- Fixed a bug on the watch mode where paths that contained `.git` (like
Expand Down
10 changes: 10 additions & 0 deletions docs/docs/experiments/remote_taskfiles.md
Expand Up @@ -54,6 +54,16 @@ Taskfiles:
code `104` (not trusted) and nothing will run. If you accept the prompt, the
checksum will be updated and the remote Taskfile will run.

Sometimes you need to run Task in an environment that does not have an
interactive terminal, so you are not able to accept a prompt. In these cases you
are able to tell task to accept these prompts automatically by using the `--yes`
flag. Before enabling this flag, you should:

1. Be sure that you trust the source and contents of the remote Taskfile.
2. Consider using a pinned version of the remote Taskfile (e.g. A link
containing a commit hash) to prevent Task from automatically accepting a
prompt that says a remote Taskfile has changed.

Task currently supports both `http` and `https` URLs. However, the `http`
requests will not execute by default unless you run the task with the
`--insecure` flag. This is to protect you from accidentally running a remote
Expand Down
47 changes: 37 additions & 10 deletions internal/logger/logger.go
Expand Up @@ -9,6 +9,14 @@ import (

"github.com/fatih/color"
"golang.org/x/exp/slices"

"github.com/go-task/task/v3/errors"
"github.com/go-task/task/v3/internal/term"
)

var (
ErrPromptCancelled = errors.New("prompt cancelled")
ErrNoTerminal = errors.New("no terminal")
)

type (
Expand Down Expand Up @@ -59,10 +67,13 @@ func envColor(env string, defaultColor color.Attribute) color.Attribute {
// Logger is just a wrapper that prints stuff to STDOUT or STDERR,
// with optional color.
type Logger struct {
Stdout io.Writer
Stderr io.Writer
Verbose bool
Color bool
Stdin io.Reader
Stdout io.Writer
Stderr io.Writer
Verbose bool
Color bool
AssumeYes bool
AssumeTerm bool // Used for testing
}

// Outf prints stuff to STDOUT.
Expand Down Expand Up @@ -108,16 +119,32 @@ func (l *Logger) VerboseErrf(color Color, s string, args ...any) {
}
}

func (l *Logger) Prompt(color Color, s string, defaultValue string, continueValues ...string) (bool, error) {
func (l *Logger) Prompt(color Color, prompt string, defaultValue string, continueValues ...string) error {
if l.AssumeYes {
l.Outf(color, "%s [assuming yes]\n", prompt)
return nil
}

if !l.AssumeTerm && !term.IsTerminal() {
return ErrNoTerminal
}

if len(continueValues) == 0 {
return false, nil
return errors.New("no continue values provided")
}
l.Outf(color, "%s [%s/%s]\n", s, strings.ToLower(continueValues[0]), strings.ToUpper(defaultValue))
reader := bufio.NewReader(os.Stdin)

l.Outf(color, "%s [%s/%s]\n", prompt, strings.ToLower(continueValues[0]), strings.ToUpper(defaultValue))

reader := bufio.NewReader(l.Stdin)
input, err := reader.ReadString('\n')
if err != nil {
return false, err
return err
}

input = strings.TrimSpace(strings.ToLower(input))
return slices.Contains(continueValues, input), nil
if !slices.Contains(continueValues, input) {
return ErrPromptCancelled
}

return nil
}
11 changes: 7 additions & 4 deletions setup.go
Expand Up @@ -157,10 +157,13 @@ func (e *Executor) setupStdFiles() {

func (e *Executor) setupLogger() {
e.Logger = &logger.Logger{
Stdout: e.Stdout,
Stderr: e.Stderr,
Verbose: e.Verbose,
Color: e.Color,
Stdin: e.Stdin,
Stdout: e.Stdout,
Stderr: e.Stderr,
Verbose: e.Verbose,
Color: e.Color,
AssumeYes: e.AssumeYes,
AssumeTerm: e.AssumeTerm,
}
}

Expand Down
29 changes: 6 additions & 23 deletions task.go
@@ -1,13 +1,11 @@
package task

import (
"bufio"
"context"
"fmt"
"io"
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
Expand All @@ -23,7 +21,6 @@ import (
"github.com/go-task/task/v3/internal/sort"
"github.com/go-task/task/v3/internal/summary"
"github.com/go-task/task/v3/internal/templater"
"github.com/go-task/task/v3/internal/term"
"github.com/go-task/task/v3/taskfile"

"github.com/sajari/fuzzy"
Expand All @@ -37,11 +34,6 @@ const (
MaximumTaskCall = 1000
)

func shouldPromptContinue(input string) bool {
input = strings.ToLower(strings.TrimSpace(input))
return slices.Contains([]string{"y", "yes"}, input)
}

// Executor executes a Taskfile
type Executor struct {
Taskfile *taskfile.Taskfile
Expand All @@ -58,13 +50,13 @@ type Executor struct {
Verbose bool
Silent bool
AssumeYes bool
AssumeTerm bool // Used for testing
Dry bool
Summary bool
Parallel bool
Color bool
Concurrency int
Interval time.Duration
AssumesTerm bool

Stdin io.Reader
Stdout io.Writer
Expand Down Expand Up @@ -182,22 +174,13 @@ func (e *Executor) RunTask(ctx context.Context, call taskfile.Call) error {
release := e.acquireConcurrencyLimit()
defer release()

if t.Prompt != "" && !e.AssumeYes {
if !e.AssumesTerm && !term.IsTerminal() {
if t.Prompt != "" {
if err := e.Logger.Prompt(logger.Yellow, t.Prompt, "n", "y", "yes"); errors.Is(err, logger.ErrNoTerminal) {
return &errors.TaskCancelledNoTerminalError{TaskName: call.Task}
}

e.Logger.Outf(logger.Yellow, "task: %q [y/N]: ", t.Prompt)

reader := bufio.NewReader(e.Stdin)
userInput, err := reader.ReadString('\n')
if err != nil {
return err
}

userInput = strings.ToLower(strings.TrimSpace(userInput))
if !shouldPromptContinue(userInput) {
} else if errors.Is(err, logger.ErrPromptCancelled) {
return &errors.TaskCancelledByUserError{TaskName: call.Task}
} else if err != nil {
return err
}
}

Expand Down
20 changes: 10 additions & 10 deletions task_test.go
Expand Up @@ -681,11 +681,11 @@ func TestPromptInSummary(t *testing.T) {
inBuff.Write([]byte(test.input))

e := task.Executor{
Dir: dir,
Stdin: &inBuff,
Stdout: &outBuff,
Stderr: &errBuff,
AssumesTerm: true,
Dir: dir,
Stdin: &inBuff,
Stdout: &outBuff,
Stderr: &errBuff,
AssumeTerm: true,
}
require.NoError(t, e.Setup())

Expand All @@ -709,11 +709,11 @@ func TestPromptWithIndirectTask(t *testing.T) {
inBuff.Write([]byte("y\n"))

e := task.Executor{
Dir: dir,
Stdin: &inBuff,
Stdout: &outBuff,
Stderr: &errBuff,
AssumesTerm: true,
Dir: dir,
Stdin: &inBuff,
Stdout: &outBuff,
Stderr: &errBuff,
AssumeTerm: true,
}
require.NoError(t, e.Setup())

Expand Down
18 changes: 9 additions & 9 deletions taskfile/read/taskfile.go
Expand Up @@ -86,19 +86,19 @@ func readTaskfile(
checksum := checksum(b)
cachedChecksum := cache.readChecksum(node)

// If the checksum doesn't exist, prompt the user to continue
var msg string
if cachedChecksum == "" {
if cont, err := l.Prompt(logger.Yellow, fmt.Sprintf("The task you are attempting to run depends on the remote Taskfile at %q.\n--- Make sure you trust the source of this Taskfile before continuing ---\nContinue?", node.Location()), "n", "y", "yes"); err != nil {
return nil, err
} else if !cont {
return nil, &errors.TaskfileNotTrustedError{URI: node.Location()}
}
// If the checksum doesn't exist, prompt the user to continue
msg = fmt.Sprintf("The task you are attempting to run depends on the remote Taskfile at %q.\n--- Make sure you trust the source of this Taskfile before continuing ---\nContinue?", node.Location())
} else if checksum != cachedChecksum {
// If there is a cached hash, but it doesn't match the expected hash, prompt the user to continue
if cont, err := l.Prompt(logger.Yellow, fmt.Sprintf("The Taskfile at %q has changed since you last used it!\n--- Make sure you trust the source of this Taskfile before continuing ---\nContinue?", node.Location()), "n", "y", "yes"); err != nil {
return nil, err
} else if !cont {
msg = fmt.Sprintf("The Taskfile at %q has changed since you last used it!\n--- Make sure you trust the source of this Taskfile before continuing ---\nContinue?", node.Location())
}
if msg != "" {
if err := l.Prompt(logger.Yellow, msg, "n", "y", "yes"); errors.Is(err, logger.ErrPromptCancelled) {
return nil, &errors.TaskfileNotTrustedError{URI: node.Location()}
} else if err != nil {
return nil, err
}
}

Expand Down

0 comments on commit dc77286

Please sign in to comment.