Skip to content

Commit

Permalink
Merge pull request #520 from ibuildthecloud/main
Browse files Browse the repository at this point in the history
chore: add progress output for builtins specifically sys.exec
  • Loading branch information
ibuildthecloud committed Jun 20, 2024
2 parents 889eff2 + e9c2bf9 commit 81d3b48
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 28 deletions.
73 changes: 51 additions & 22 deletions pkg/builtin/builtin.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package builtin

import (
"bytes"
"context"
"encoding/json"
"errors"
Expand Down Expand Up @@ -264,7 +265,7 @@ func Builtin(name string) (types.Tool, bool) {
return SetDefaults(t), ok
}

func SysFind(_ context.Context, _ []string, input string) (string, error) {
func SysFind(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var result []string
var params struct {
Pattern string `json:"pattern,omitempty"`
Expand Down Expand Up @@ -305,7 +306,7 @@ func SysFind(_ context.Context, _ []string, input string) (string, error) {
return strings.Join(result, "\n"), nil
}

func SysExec(_ context.Context, env []string, input string) (string, error) {
func SysExec(_ context.Context, env []string, input string, progress chan<- string) (string, error) {
var params struct {
Command string `json:"command,omitempty"`
Directory string `json:"directory,omitempty"`
Expand All @@ -328,13 +329,30 @@ func SysExec(_ context.Context, env []string, input string) (string, error) {
cmd = exec.Command("/bin/sh", "-c", params.Command)
}

var (
out bytes.Buffer
pw = progressWriter{
out: progress,
}
combined = io.MultiWriter(&out, &pw)
)
cmd.Env = env
cmd.Dir = params.Directory
out, err := cmd.CombinedOutput()
if err != nil {
return fmt.Sprintf("ERROR: %s\nOUTPUT:\n%s", err, out), nil
cmd.Stdout = combined
cmd.Stderr = combined
if err := cmd.Run(); err != nil {
return fmt.Sprintf("ERROR: %s\nOUTPUT:\n%s", err, &out), nil
}
return string(out), nil
return out.String(), nil
}

type progressWriter struct {
out chan<- string
}

func (pw *progressWriter) Write(p []byte) (n int, err error) {
pw.out <- string(p)
return len(p), nil
}

func getWorkspaceDir(envs []string) (string, error) {
Expand All @@ -347,7 +365,7 @@ func getWorkspaceDir(envs []string) (string, error) {
return "", fmt.Errorf("no workspace directory found in env")
}

func SysLs(_ context.Context, _ []string, input string) (string, error) {
func SysLs(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Dir string `json:"dir,omitempty"`
}
Expand Down Expand Up @@ -383,7 +401,7 @@ func SysLs(_ context.Context, _ []string, input string) (string, error) {
return strings.Join(result, "\n"), nil
}

func SysRead(_ context.Context, _ []string, input string) (string, error) {
func SysRead(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Filename string `json:"filename,omitempty"`
}
Expand Down Expand Up @@ -411,7 +429,7 @@ func SysRead(_ context.Context, _ []string, input string) (string, error) {
return string(data), nil
}

func SysWrite(_ context.Context, _ []string, input string) (string, error) {
func SysWrite(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Filename string `json:"filename,omitempty"`
Content string `json:"content,omitempty"`
Expand Down Expand Up @@ -443,7 +461,7 @@ func SysWrite(_ context.Context, _ []string, input string) (string, error) {
return fmt.Sprintf("Wrote (%d) bytes to file %s", len(data), file), nil
}

func SysAppend(_ context.Context, _ []string, input string) (string, error) {
func SysAppend(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Filename string `json:"filename,omitempty"`
Content string `json:"content,omitempty"`
Expand Down Expand Up @@ -489,7 +507,7 @@ func fixQueries(u string) string {
return url.String()
}

func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err error) {
func SysHTTPGet(_ context.Context, _ []string, input string, _ chan<- string) (_ string, err error) {
var params struct {
URL string `json:"url,omitempty"`
}
Expand Down Expand Up @@ -523,8 +541,8 @@ func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err erro
return string(data), nil
}

func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string, error) {
content, err := SysHTTPGet(ctx, env, input)
func SysHTTPHtml2Text(ctx context.Context, env []string, input string, progress chan<- string) (string, error) {
content, err := SysHTTPGet(ctx, env, input, progress)
if err != nil {
return "", err
}
Expand All @@ -533,7 +551,7 @@ func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string,
})
}

func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err error) {
func SysHTTPPost(ctx context.Context, _ []string, input string, _ chan<- string) (_ string, err error) {
var params struct {
URL string `json:"url,omitempty"`
Content string `json:"content,omitempty"`
Expand Down Expand Up @@ -569,7 +587,18 @@ func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err e
return fmt.Sprintf("Wrote %d to %s", len([]byte(params.Content)), params.URL), nil
}

func SysGetenv(_ context.Context, env []string, input string) (string, error) {
func DiscardProgress() (progress chan<- string, closeFunc func()) {
ch := make(chan string)
go func() {
for range ch {
}
}()
return ch, func() {
close(ch)
}
}

func SysGetenv(_ context.Context, env []string, input string, _ chan<- string) (string, error) {
var params struct {
Name string `json:"name,omitempty"`
}
Expand Down Expand Up @@ -597,7 +626,7 @@ func invalidArgument(input string, err error) string {
return fmt.Sprintf("Failed to parse arguments %s: %v", input, err)
}

func SysChatHistory(ctx context.Context, _ []string, _ string) (string, error) {
func SysChatHistory(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
engineContext, _ := engine.FromContext(ctx)

data, err := json.Marshal(engine.ChatHistory{
Expand Down Expand Up @@ -627,7 +656,7 @@ func writeHistory(ctx *engine.Context) (result []engine.ChatHistoryCall) {
return
}

func SysChatFinish(_ context.Context, _ []string, input string) (string, error) {
func SysChatFinish(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Message string `json:"return,omitempty"`
}
Expand All @@ -641,7 +670,7 @@ func SysChatFinish(_ context.Context, _ []string, input string) (string, error)
}
}

func SysAbort(_ context.Context, _ []string, input string) (string, error) {
func SysAbort(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Message string `json:"message,omitempty"`
}
Expand All @@ -651,7 +680,7 @@ func SysAbort(_ context.Context, _ []string, input string) (string, error) {
return "", fmt.Errorf("ABORT: %s", params.Message)
}

func SysRemove(_ context.Context, _ []string, input string) (string, error) {
func SysRemove(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Location string `json:"location,omitempty"`
}
Expand All @@ -670,7 +699,7 @@ func SysRemove(_ context.Context, _ []string, input string) (string, error) {
return fmt.Sprintf("Removed file: %s", params.Location), nil
}

func SysStat(_ context.Context, _ []string, input string) (string, error) {
func SysStat(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Filepath string `json:"filepath,omitempty"`
}
Expand All @@ -690,7 +719,7 @@ func SysStat(_ context.Context, _ []string, input string) (string, error) {
return fmt.Sprintf("%s %s mode: %s, size: %d bytes, modtime: %s", title, params.Filepath, stat.Mode().String(), stat.Size(), stat.ModTime().String()), nil
}

func SysDownload(_ context.Context, env []string, input string) (_ string, err error) {
func SysDownload(_ context.Context, env []string, input string, _ chan<- string) (_ string, err error) {
var params struct {
URL string `json:"url,omitempty"`
Location string `json:"location,omitempty"`
Expand Down Expand Up @@ -763,6 +792,6 @@ func SysDownload(_ context.Context, env []string, input string) (_ string, err e
return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil
}

func SysTimeNow(context.Context, []string, string) (string, error) {
func SysTimeNow(context.Context, []string, string, chan<- string) (string, error) {
return time.Now().Format(time.RFC3339), nil
}
6 changes: 4 additions & 2 deletions pkg/builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ import (
)

func TestSysGetenv(t *testing.T) {
p, c := DiscardProgress()
defer c()
v, err := SysGetenv(context.Background(), []string{
"MAGIC=VALUE",
}, `{"name":"MAGIC"}`)
}, `{"name":"MAGIC"}`, nil)
require.NoError(t, err)
autogold.Expect("VALUE").Equal(t, v)

v, err = SysGetenv(context.Background(), []string{
"MAGIC=VALUE",
}, `{"name":"MAGIC2"}`)
}, `{"name":"MAGIC2"}`, p)
require.NoError(t, err)
autogold.Expect("MAGIC2 is not set or has no value").Equal(t, v)
}
Expand Down
26 changes: 25 additions & 1 deletion pkg/engine/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"runtime"
"sort"
"strings"
"sync"

"github.com/google/shlex"
"github.com/gptscript-ai/gptscript/pkg/counter"
Expand Down Expand Up @@ -64,7 +65,30 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate
"input": input,
},
}
return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input)

var (
progress = make(chan string)
wg sync.WaitGroup
)
wg.Add(1)
defer wg.Wait()
defer close(progress)
go func() {
defer wg.Done()
buf := strings.Builder{}
for line := range progress {
buf.WriteString(line)
e.Progress <- types.CompletionStatus{
CompletionID: id,
PartialResponse: &types.CompletionMessage{
Role: types.CompletionMessageRoleTypeAssistant,
Content: types.Text(buf.String()),
},
}
}
}()

return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input, progress)
}

var instructions []string
Expand Down
3 changes: 2 additions & 1 deletion pkg/prompt/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ func GetModelProviderCredential(ctx context.Context, credStore credentials.Crede
if exists {
k = cred.Env[env]
} else {
result, err := SysPrompt(ctx, envs, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message))
// we know progress isn't used so pass as nil
result, err := SysPrompt(ctx, envs, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message), nil)
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/prompt/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types.
return string(data), err
}

func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err error) {
func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string) (_ string, err error) {
var params struct {
Message string `json:"message,omitempty"`
Fields string `json:"fields,omitempty"`
Expand Down
2 changes: 1 addition & 1 deletion pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (p Program) SetBlocking() Program {
return p
}

type BuiltinFunc func(ctx context.Context, env []string, input string) (string, error)
type BuiltinFunc func(ctx context.Context, env []string, input string, progress chan<- string) (string, error)

type Parameters struct {
Name string `json:"name,omitempty"`
Expand Down

0 comments on commit 81d3b48

Please sign in to comment.