diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index c4d68e47..b3dd7962 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -146,7 +146,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) error { defer gptScript.Close() if r.ListModels { - models, err := gptScript.ListModels(cmd.Context()) + models, err := gptScript.ListModels(cmd.Context(), args...) if err != nil { return err } diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index 6e593880..323917a7 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -103,6 +103,6 @@ func (g *GPTScript) GetModel() engine.Model { return g.Registry } -func (g *GPTScript) ListModels(ctx context.Context) ([]string, error) { - return g.Registry.ListModels(ctx) +func (g *GPTScript) ListModels(ctx context.Context, providers ...string) ([]string, error) { + return g.Registry.ListModels(ctx, providers...) } diff --git a/pkg/llm/registry.go b/pkg/llm/registry.go index 166d2197..5ab99087 100644 --- a/pkg/llm/registry.go +++ b/pkg/llm/registry.go @@ -11,7 +11,7 @@ import ( type Client interface { Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) - ListModels(ctx context.Context) (result []string, _ error) + ListModels(ctx context.Context, providers ...string) (result []string, _ error) Supports(ctx context.Context, modelName string) (bool, error) } @@ -28,9 +28,9 @@ func (r *Registry) AddClient(client Client) error { return nil } -func (r *Registry) ListModels(ctx context.Context) (result []string, _ error) { +func (r *Registry) ListModels(ctx context.Context, providers ...string) (result []string, _ error) { for _, v := range r.clients { - models, err := v.ListModels(ctx) + models, err := v.ListModels(ctx, providers...) if err != nil { return nil, err } diff --git a/pkg/openai/client.go b/pkg/openai/client.go index 3aeace0c..884acb08 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -127,7 +127,12 @@ func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) { return slices.Contains(models, modelName), nil } -func (c *Client) ListModels(ctx context.Context) (result []string, _ error) { +func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) { + // Only serve if providers is empty or "" is in the list + if len(providers) != 0 && !slices.Contains(providers, "") { + return nil, nil + } + models, err := c.c.ListModels(ctx) if err != nil { return nil, err diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 9bb0c634..6dcd3e60 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -15,7 +15,6 @@ import ( "github.com/gptscript-ai/gptscript/pkg/openai" "github.com/gptscript-ai/gptscript/pkg/runner" "github.com/gptscript-ai/gptscript/pkg/types" - "golang.org/x/exp/maps" ) type Client struct { @@ -49,13 +48,23 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques return client.Call(ctx, messageRequest, status) } -func (c *Client) ListModels(_ context.Context) (result []string, _ error) { - c.clientsLock.Lock() - defer c.clientsLock.Unlock() +func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) { + for _, provider := range providers { + client, err := c.load(ctx, provider) + if err != nil { + return nil, err + } + models, err := client.ListModels(ctx, "") + if err != nil { + return nil, err + } + for _, model := range models { + result = append(result, model+" from "+provider) + } + } - keys := maps.Keys(c.models) - sort.Strings(keys) - return keys, nil + sort.Strings(result) + return } func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {