From 398fa80035fcb7cb03f7c826322fb921bb98abec Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Tue, 5 Nov 2024 19:59:14 -0500 Subject: [PATCH] chore: allow setting of dataset tool in SDK server config Signed-off-by: Donnie Adams --- pkg/cli/sdk_server.go | 2 ++ pkg/sdkserver/datasets.go | 27 ++++++++++++++------------- pkg/sdkserver/routes.go | 10 +++++----- pkg/sdkserver/server.go | 13 +++++++++---- 4 files changed, 30 insertions(+), 22 deletions(-) diff --git a/pkg/cli/sdk_server.go b/pkg/cli/sdk_server.go index 5ce65305..42f0f949 100644 --- a/pkg/cli/sdk_server.go +++ b/pkg/cli/sdk_server.go @@ -11,6 +11,7 @@ import ( type SDKServer struct { *GPTScript + DatasetTool string `usage:"Tool to use for datasets"` WorkspaceTool string `usage:"Tool to use for workspace"` } @@ -38,6 +39,7 @@ func (c *SDKServer) Run(cmd *cobra.Command, _ []string) error { Options: opts, ListenAddress: c.ListenAddress, Debug: c.Debug, + DatasetTool: c.DatasetTool, WorkspaceTool: c.WorkspaceTool, }) } diff --git a/pkg/sdkserver/datasets.go b/pkg/sdkserver/datasets.go index 1a547953..5db90bf7 100644 --- a/pkg/sdkserver/datasets.go +++ b/pkg/sdkserver/datasets.go @@ -10,6 +10,14 @@ import ( "github.com/gptscript-ai/gptscript/pkg/loader" ) +func (s *server) getDatasetTool(req datasetRequest) string { + if req.DatasetToolRepo != "" { + return req.DatasetToolRepo + } + + return s.datasetTool +} + type datasetRequest struct { Input string `json:"input"` WorkspaceID string `json:"workspaceID"` @@ -38,13 +46,6 @@ func (r datasetRequest) opts(o gptscript.Options) gptscript.Options { return opts } -func (r datasetRequest) getToolRepo() string { - if r.DatasetToolRepo != "" { - return r.DatasetToolRepo - } - return "github.com/otto8-ai/datasets" -} - func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) { logger := gcontext.GetLogger(r.Context()) @@ -65,7 +66,7 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) { return } - prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Datasets", loader.Options{ + prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "List Datasets", loader.Options{ Cache: g.Cache, }) @@ -126,7 +127,7 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) { return } - prg, err := loader.Program(r.Context(), req.getToolRepo(), "Create Dataset", loader.Options{ + prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Create Dataset", loader.Options{ Cache: g.Cache, }) @@ -195,7 +196,7 @@ func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) { return } - prg, err := loader.Program(r.Context(), req.getToolRepo(), "Add Element", loader.Options{ + prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Add Element", loader.Options{ Cache: g.Cache, }) if err != nil { @@ -262,7 +263,7 @@ func (s *server) addDatasetElements(w http.ResponseWriter, r *http.Request) { return } - prg, err := loader.Program(r.Context(), req.getToolRepo(), "Add Elements", loader.Options{ + prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Add Elements", loader.Options{ Cache: g.Cache, }) if err != nil { @@ -327,7 +328,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) { return } - prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Elements", loader.Options{ + prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "List Elements", loader.Options{ Cache: g.Cache, }) if err != nil { @@ -390,7 +391,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) { return } - prg, err := loader.Program(r.Context(), req.getToolRepo(), "Get Element SDK", loader.Options{ + prg, err := loader.Program(r.Context(), s.getDatasetTool(req), "Get Element SDK", loader.Options{ Cache: g.Cache, }) if err != nil { diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index e9b1cca8..ea7fdb09 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -26,11 +26,11 @@ import ( ) type server struct { - gptscriptOpts gptscript.Options - address, token string - workspaceTool string - client *gptscript.GPTScript - events *broadcaster.Broadcaster[event] + gptscriptOpts gptscript.Options + address, token string + datasetTool, workspaceTool string + client *gptscript.GPTScript + events *broadcaster.Broadcaster[event] runtimeManager engine.RuntimeManager diff --git a/pkg/sdkserver/server.go b/pkg/sdkserver/server.go index f0c61940..7d98ae60 100644 --- a/pkg/sdkserver/server.go +++ b/pkg/sdkserver/server.go @@ -26,10 +26,10 @@ import ( type Options struct { gptscript.Options - ListenAddress string - WorkspaceTool string - Debug bool - DisableServerErrorLogging bool + ListenAddress string + DatasetTool, WorkspaceTool string + Debug bool + DisableServerErrorLogging bool } // Run will start the server and block until the server is shut down. @@ -108,6 +108,7 @@ func run(ctx context.Context, listener net.Listener, opts Options) error { gptscriptOpts: opts.Options, address: listener.Addr().String(), token: token, + datasetTool: opts.DatasetTool, workspaceTool: opts.WorkspaceTool, client: g, events: events, @@ -159,6 +160,7 @@ func complete(opts ...Options) Options { for _, opt := range opts { result.Options = gptscript.Complete(result.Options, opt.Options) result.ListenAddress = types.FirstSet(opt.ListenAddress, result.ListenAddress) + result.DatasetTool = types.FirstSet(opt.DatasetTool, result.DatasetTool) result.WorkspaceTool = types.FirstSet(opt.WorkspaceTool, result.WorkspaceTool) result.Debug = types.FirstSet(opt.Debug, result.Debug) result.DisableServerErrorLogging = types.FirstSet(opt.DisableServerErrorLogging, result.DisableServerErrorLogging) @@ -171,6 +173,9 @@ func complete(opts ...Options) Options { if result.WorkspaceTool == "" { result.WorkspaceTool = "github.com/gptscript-ai/workspace-provider" } + if result.DatasetTool == "" { + result.DatasetTool = "github.com/otto8-ai/datasets" + } return result }