Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add progress output for builtins specifically sys.exec #520

Merged
merged 1 commit into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 51 additions & 22 deletions pkg/builtin/builtin.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package builtin

import (
"bytes"
"context"
"encoding/json"
"errors"
Expand Down Expand Up @@ -264,7 +265,7 @@ func Builtin(name string) (types.Tool, bool) {
return SetDefaults(t), ok
}

func SysFind(_ context.Context, _ []string, input string) (string, error) {
func SysFind(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var result []string
var params struct {
Pattern string `json:"pattern,omitempty"`
Expand Down Expand Up @@ -305,7 +306,7 @@ func SysFind(_ context.Context, _ []string, input string) (string, error) {
return strings.Join(result, "\n"), nil
}

func SysExec(_ context.Context, env []string, input string) (string, error) {
func SysExec(_ context.Context, env []string, input string, progress chan<- string) (string, error) {
var params struct {
Command string `json:"command,omitempty"`
Directory string `json:"directory,omitempty"`
Expand All @@ -328,13 +329,30 @@ func SysExec(_ context.Context, env []string, input string) (string, error) {
cmd = exec.Command("/bin/sh", "-c", params.Command)
}

var (
out bytes.Buffer
pw = progressWriter{
out: progress,
}
combined = io.MultiWriter(&out, &pw)
)
cmd.Env = env
cmd.Dir = params.Directory
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Sprintf("ERROR: %s\nOUTPUT:\n%s", err, out), nil
cmd.Stdout = combined
cmd.Stderr = combined
if err := cmd.Run(); err != nil {
return fmt.Sprintf("ERROR: %s\nOUTPUT:\n%s", err, &out), nil
}
return string(out), nil
return out.String(), nil
}

type progressWriter struct {
out chan<- string
}

func (pw *progressWriter) Write(p []byte) (n int, err error) {
pw.out <- string(p)
return len(p), nil
}

func getWorkspaceDir(envs []string) (string, error) {
Expand All @@ -347,7 +365,7 @@ func getWorkspaceDir(envs []string) (string, error) {
return "", fmt.Errorf("no workspace directory found in env")
}

func SysLs(_ context.Context, _ []string, input string) (string, error) {
func SysLs(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Dir string `json:"dir,omitempty"`
}
Expand Down Expand Up @@ -383,7 +401,7 @@ func SysLs(_ context.Context, _ []string, input string) (string, error) {
return strings.Join(result, "\n"), nil
}

func SysRead(_ context.Context, _ []string, input string) (string, error) {
func SysRead(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Filename string `json:"filename,omitempty"`
}
Expand Down Expand Up @@ -411,7 +429,7 @@ func SysRead(_ context.Context, _ []string, input string) (string, error) {
return string(data), nil
}

func SysWrite(_ context.Context, _ []string, input string) (string, error) {
func SysWrite(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Filename string `json:"filename,omitempty"`
Content string `json:"content,omitempty"`
Expand Down Expand Up @@ -443,7 +461,7 @@ func SysWrite(_ context.Context, _ []string, input string) (string, error) {
return fmt.Sprintf("Wrote (%d) bytes to file %s", len(data), file), nil
}

func SysAppend(_ context.Context, _ []string, input string) (string, error) {
func SysAppend(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Filename string `json:"filename,omitempty"`
Content string `json:"content,omitempty"`
Expand Down Expand Up @@ -489,7 +507,7 @@ func fixQueries(u string) string {
return url.String()
}

func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err error) {
func SysHTTPGet(_ context.Context, _ []string, input string, _ chan<- string) (_ string, err error) {
var params struct {
URL string `json:"url,omitempty"`
}
Expand Down Expand Up @@ -523,8 +541,8 @@ func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err erro
return string(data), nil
}

func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string, error) {
content, err := SysHTTPGet(ctx, env, input)
func SysHTTPHtml2Text(ctx context.Context, env []string, input string, progress chan<- string) (string, error) {
content, err := SysHTTPGet(ctx, env, input, progress)
if err != nil {
return "", err
}
Expand All @@ -533,7 +551,7 @@ func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string,
})
}

func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err error) {
func SysHTTPPost(ctx context.Context, _ []string, input string, _ chan<- string) (_ string, err error) {
var params struct {
URL string `json:"url,omitempty"`
Content string `json:"content,omitempty"`
Expand Down Expand Up @@ -569,7 +587,18 @@ func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err e
return fmt.Sprintf("Wrote %d to %s", len([]byte(params.Content)), params.URL), nil
}

func SysGetenv(_ context.Context, env []string, input string) (string, error) {
func DiscardProgress() (progress chan<- string, closeFunc func()) {
ch := make(chan string)
go func() {
for range ch {
}
}()
return ch, func() {
close(ch)
}
}

func SysGetenv(_ context.Context, env []string, input string, _ chan<- string) (string, error) {
var params struct {
Name string `json:"name,omitempty"`
}
Expand Down Expand Up @@ -597,7 +626,7 @@ func invalidArgument(input string, err error) string {
return fmt.Sprintf("Failed to parse arguments %s: %v", input, err)
}

func SysChatHistory(ctx context.Context, _ []string, _ string) (string, error) {
func SysChatHistory(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
engineContext, _ := engine.FromContext(ctx)

data, err := json.Marshal(engine.ChatHistory{
Expand Down Expand Up @@ -627,7 +656,7 @@ func writeHistory(ctx *engine.Context) (result []engine.ChatHistoryCall) {
return
}

func SysChatFinish(_ context.Context, _ []string, input string) (string, error) {
func SysChatFinish(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Message string `json:"return,omitempty"`
}
Expand All @@ -641,7 +670,7 @@ func SysChatFinish(_ context.Context, _ []string, input string) (string, error)
}
}

func SysAbort(_ context.Context, _ []string, input string) (string, error) {
func SysAbort(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Message string `json:"message,omitempty"`
}
Expand All @@ -651,7 +680,7 @@ func SysAbort(_ context.Context, _ []string, input string) (string, error) {
return "", fmt.Errorf("ABORT: %s", params.Message)
}

func SysRemove(_ context.Context, _ []string, input string) (string, error) {
func SysRemove(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Location string `json:"location,omitempty"`
}
Expand All @@ -670,7 +699,7 @@ func SysRemove(_ context.Context, _ []string, input string) (string, error) {
return fmt.Sprintf("Removed file: %s", params.Location), nil
}

func SysStat(_ context.Context, _ []string, input string) (string, error) {
func SysStat(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Filepath string `json:"filepath,omitempty"`
}
Expand All @@ -690,7 +719,7 @@ func SysStat(_ context.Context, _ []string, input string) (string, error) {
return fmt.Sprintf("%s %s mode: %s, size: %d bytes, modtime: %s", title, params.Filepath, stat.Mode().String(), stat.Size(), stat.ModTime().String()), nil
}

func SysDownload(_ context.Context, env []string, input string) (_ string, err error) {
func SysDownload(_ context.Context, env []string, input string, _ chan<- string) (_ string, err error) {
var params struct {
URL string `json:"url,omitempty"`
Location string `json:"location,omitempty"`
Expand Down Expand Up @@ -763,6 +792,6 @@ func SysDownload(_ context.Context, env []string, input string) (_ string, err e
return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil
}

func SysTimeNow(context.Context, []string, string) (string, error) {
func SysTimeNow(context.Context, []string, string, chan<- string) (string, error) {
return time.Now().Format(time.RFC3339), nil
}
6 changes: 4 additions & 2 deletions pkg/builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ import (
)

func TestSysGetenv(t *testing.T) {
p, c := DiscardProgress()
defer c()
v, err := SysGetenv(context.Background(), []string{
"MAGIC=VALUE",
}, `{"name":"MAGIC"}`)
}, `{"name":"MAGIC"}`, nil)
require.NoError(t, err)
autogold.Expect("VALUE").Equal(t, v)

v, err = SysGetenv(context.Background(), []string{
"MAGIC=VALUE",
}, `{"name":"MAGIC2"}`)
}, `{"name":"MAGIC2"}`, p)
require.NoError(t, err)
autogold.Expect("MAGIC2 is not set or has no value").Equal(t, v)
}
Expand Down
26 changes: 25 additions & 1 deletion pkg/engine/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"runtime"
"sort"
"strings"
"sync"

"github.com/google/shlex"
"github.com/gptscript-ai/gptscript/pkg/counter"
Expand Down Expand Up @@ -64,7 +65,30 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate
"input": input,
},
}
return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input)

var (
progress = make(chan string)
wg sync.WaitGroup
)
wg.Add(1)
defer wg.Wait()
defer close(progress)
go func() {
defer wg.Done()
buf := strings.Builder{}
for line := range progress {
buf.WriteString(line)
e.Progress <- types.CompletionStatus{
CompletionID: id,
PartialResponse: &types.CompletionMessage{
Role: types.CompletionMessageRoleTypeAssistant,
Content: types.Text(buf.String()),
},
}
}
}()

return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input, progress)
}

var instructions []string
Expand Down
3 changes: 2 additions & 1 deletion pkg/prompt/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ func GetModelProviderCredential(ctx context.Context, credStore credentials.Crede
if exists {
k = cred.Env[env]
} else {
result, err := SysPrompt(ctx, envs, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message))
// we know progress isn't used so pass as nil
result, err := SysPrompt(ctx, envs, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message), nil)
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/prompt/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types.
return string(data), err
}

func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err error) {
func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string) (_ string, err error) {
var params struct {
Message string `json:"message,omitempty"`
Fields string `json:"fields,omitempty"`
Expand Down
2 changes: 1 addition & 1 deletion pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (p Program) SetBlocking() Program {
return p
}

type BuiltinFunc func(ctx context.Context, env []string, input string) (string, error)
type BuiltinFunc func(ctx context.Context, env []string, input string, progress chan<- string) (string, error)

type Parameters struct {
Name string `json:"name,omitempty"`
Expand Down