From 57b9ca1ddd20eb5edea0dc625efbeda2d6a1c620 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 11:11:45 -0500 Subject: [PATCH 01/28] Pull client earlier, pass to each command Trying to make it an injected thing for easier testing. --- cmd/list/list.go | 15 +++------------ cmd/root.go | 19 ++++++++++++++++--- cmd/run/run.go | 14 ++++---------- cmd/view/view.go | 14 ++------------ 4 files changed, 25 insertions(+), 37 deletions(-) diff --git a/cmd/list/list.go b/cmd/list/list.go index 6ab3088a..34a5ecda 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -4,7 +4,6 @@ package list import ( "fmt" - "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/tableprinter" "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" @@ -19,22 +18,12 @@ var ( ) // NewListCommand returns a new command to list available GitHub models. -func NewListCommand() *cobra.Command { +func NewListCommand(client *azuremodels.Client) *cobra.Command { cmd := &cobra.Command{ Use: "list", Short: "List available models", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - terminal := term.FromEnv() - out := terminal.Out() - - token, _ := auth.TokenForHost("github.com") - if token == "" { - util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") - return nil - } - - client := azuremodels.NewClient(token) ctx := cmd.Context() models, err := client.ListModels(ctx) @@ -47,6 +36,8 @@ func NewListCommand() *cobra.Command { models = filterToChatModels(models) ux.SortModels(models) + terminal := term.FromEnv() + out := terminal.Out() isTTY := terminal.IsTerminalOutput() if isTTY { diff --git a/cmd/root.go b/cmd/root.go index 3e8caf52..7e26e7f3 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -4,9 +4,13 @@ package cmd import ( "strings" + "github.com/cli/go-gh/v2/pkg/auth" + "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/cmd/list" "github.com/github/gh-models/cmd/run" "github.com/github/gh-models/cmd/view" + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" ) @@ -17,9 +21,18 @@ func NewRootCommand() *cobra.Command { Short: "GitHub Models extension", } - cmd.AddCommand(list.NewListCommand()) - cmd.AddCommand(run.NewRunCommand()) - cmd.AddCommand(view.NewViewCommand()) + token, _ := auth.TokenForHost("github.com") + if token == "" { + terminal := term.FromEnv() + util.WriteToOut(terminal.Out(), "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + return nil + } + + client := azuremodels.NewClient(token) + + cmd.AddCommand(list.NewListCommand(client)) + cmd.AddCommand(run.NewRunCommand(client)) + cmd.AddCommand(view.NewViewCommand(client)) // Cobra doesn't have a way to specify a two word command (ie. "gh models"), so set a custom usage template // with `gh`` in it. Cobra will use this template for the root and all child commands. diff --git a/cmd/run/run.go b/cmd/run/run.go index bc12d038..a67a5a22 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -14,7 +14,6 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/briandowns/spinner" - "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/sse" @@ -190,13 +189,13 @@ func isPipe(r io.Reader) bool { } // NewRunCommand returns a new gh command for running a model. -func NewRunCommand() *cobra.Command { +func NewRunCommand(client *azuremodels.Client) *cobra.Command { cmd := &cobra.Command{ Use: "run [model] [prompt]", Short: "Run inference with the specified model", Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { - cmdHandler := newRunCommandHandler(cmd, args) + cmdHandler := newRunCommandHandler(cmd, client, args) if cmdHandler == nil { return nil } @@ -374,21 +373,16 @@ type runCommandHandler struct { args []string } -func newRunCommandHandler(cmd *cobra.Command, args []string) *runCommandHandler { +func newRunCommandHandler(cmd *cobra.Command, client *azuremodels.Client, args []string) *runCommandHandler { terminal := term.FromEnv() out := terminal.Out() - token, _ := auth.TokenForHost("github.com") - if token == "" { - util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") - return nil - } return &runCommandHandler{ ctx: cmd.Context(), terminal: terminal, out: out, args: args, errOut: terminal.ErrOut(), - client: azuremodels.NewClient(token), + client: client, } } diff --git a/cmd/view/view.go b/cmd/view/view.go index 777281d7..c721d531 100644 --- a/cmd/view/view.go +++ b/cmd/view/view.go @@ -5,30 +5,19 @@ import ( "fmt" "github.com/AlecAivazis/survey/v2" - "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/ux" - "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" ) // NewViewCommand returns a new command to view details about a model. -func NewViewCommand() *cobra.Command { +func NewViewCommand(client *azuremodels.Client) *cobra.Command { cmd := &cobra.Command{ Use: "view [model]", Short: "View details about a model", Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { - terminal := term.FromEnv() - - token, _ := auth.TokenForHost("github.com") - if token == "" { - util.WriteToOut(terminal.Out(), "No GitHub token found. Please run 'gh auth login' to authenticate.\n") - return nil - } - - client := azuremodels.NewClient(token) ctx := cmd.Context() models, err := client.ListModels(ctx) @@ -73,6 +62,7 @@ func NewViewCommand() *cobra.Command { return err } + terminal := term.FromEnv() modelPrinter := newModelPrinter(modelSummary, modelDetails, terminal) err = modelPrinter.render() From dd201081a19f977815de247cca8f097390c2bfed Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 11:17:43 -0500 Subject: [PATCH 02/28] Rename Client to AzureClient --- cmd/list/list.go | 2 +- cmd/root.go | 2 +- cmd/run/run.go | 6 +++--- cmd/view/view.go | 2 +- .../azuremodels/{client.go => azure_client.go} | 16 ++++++++-------- 5 files changed, 14 insertions(+), 14 deletions(-) rename internal/azuremodels/{client.go => azure_client.go} (91%) diff --git a/cmd/list/list.go b/cmd/list/list.go index 34a5ecda..cc555d35 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -18,7 +18,7 @@ var ( ) // NewListCommand returns a new command to list available GitHub models. -func NewListCommand(client *azuremodels.Client) *cobra.Command { +func NewListCommand(client *azuremodels.AzureClient) *cobra.Command { cmd := &cobra.Command{ Use: "list", Short: "List available models", diff --git a/cmd/root.go b/cmd/root.go index 7e26e7f3..451774c7 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -28,7 +28,7 @@ func NewRootCommand() *cobra.Command { return nil } - client := azuremodels.NewClient(token) + client := azuremodels.NewAzureClient(token) cmd.AddCommand(list.NewListCommand(client)) cmd.AddCommand(run.NewRunCommand(client)) diff --git a/cmd/run/run.go b/cmd/run/run.go index a67a5a22..446444b5 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -189,7 +189,7 @@ func isPipe(r io.Reader) bool { } // NewRunCommand returns a new gh command for running a model. -func NewRunCommand(client *azuremodels.Client) *cobra.Command { +func NewRunCommand(client *azuremodels.AzureClient) *cobra.Command { cmd := &cobra.Command{ Use: "run [model] [prompt]", Short: "Run inference with the specified model", @@ -369,11 +369,11 @@ type runCommandHandler struct { terminal term.Term out io.Writer errOut io.Writer - client *azuremodels.Client + client *azuremodels.AzureClient args []string } -func newRunCommandHandler(cmd *cobra.Command, client *azuremodels.Client, args []string) *runCommandHandler { +func newRunCommandHandler(cmd *cobra.Command, client *azuremodels.AzureClient, args []string) *runCommandHandler { terminal := term.FromEnv() out := terminal.Out() return &runCommandHandler{ diff --git a/cmd/view/view.go b/cmd/view/view.go index c721d531..4d579e9d 100644 --- a/cmd/view/view.go +++ b/cmd/view/view.go @@ -12,7 +12,7 @@ import ( ) // NewViewCommand returns a new command to view details about a model. -func NewViewCommand(client *azuremodels.Client) *cobra.Command { +func NewViewCommand(client *azuremodels.AzureClient) *cobra.Command { cmd := &cobra.Command{ Use: "view [model]", Short: "View details about a model", diff --git a/internal/azuremodels/client.go b/internal/azuremodels/azure_client.go similarity index 91% rename from internal/azuremodels/client.go rename to internal/azuremodels/azure_client.go index a4b60d39..82ce619c 100644 --- a/internal/azuremodels/client.go +++ b/internal/azuremodels/azure_client.go @@ -17,8 +17,8 @@ import ( "golang.org/x/text/language/display" ) -// Client provides a client for interacting with the Azure models API. -type Client struct { +// AzureClient provides a client for interacting with the Azure models API. +type AzureClient struct { client *http.Client token string } @@ -30,16 +30,16 @@ const ( ) // NewClient returns a new client using the given auth token. -func NewClient(authToken string) *Client { +func NewAzureClient(authToken string) *AzureClient { httpClient, _ := api.DefaultHTTPClient() - return &Client{ + return &AzureClient{ client: httpClient, token: authToken, } } // GetChatCompletionStream returns a stream of chat completions for the given request. -func (c *Client) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions) (*ChatCompletionResponse, error) { +func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions) (*ChatCompletionResponse, error) { // Check if the model name is `o1-mini` or `o1-preview` if req.Model == "o1-mini" || req.Model == "o1-preview" { req.Stream = false @@ -93,7 +93,7 @@ func (c *Client) GetChatCompletionStream(ctx context.Context, req ChatCompletion } // GetModelDetails returns the details of the specified model in a prticular registry. -func (c *Client) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { +func (c *AzureClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version) httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { @@ -171,7 +171,7 @@ func lowercaseStrings(input []string) []string { } // ListModels returns a list of available models. -func (c *Client) ListModels(ctx context.Context) ([]*ModelSummary, error) { +func (c *AzureClient) ListModels(ctx context.Context) ([]*ModelSummary, error) { body := bytes.NewReader([]byte(` { "filters": [ @@ -233,7 +233,7 @@ func (c *Client) ListModels(ctx context.Context) ([]*ModelSummary, error) { return models, nil } -func (c *Client) handleHTTPError(resp *http.Response) error { +func (c *AzureClient) handleHTTPError(resp *http.Response) error { sb := strings.Builder{} var err error From c1bb4920aa7ff79815548cb8abb99f344971d1c7 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 11:21:11 -0500 Subject: [PATCH 03/28] Add Client interface --- cmd/list/list.go | 2 +- cmd/run/run.go | 6 +++--- cmd/view/view.go | 2 +- internal/azuremodels/client.go | 9 +++++++++ 4 files changed, 14 insertions(+), 5 deletions(-) create mode 100644 internal/azuremodels/client.go diff --git a/cmd/list/list.go b/cmd/list/list.go index cc555d35..b1d12224 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -18,7 +18,7 @@ var ( ) // NewListCommand returns a new command to list available GitHub models. -func NewListCommand(client *azuremodels.AzureClient) *cobra.Command { +func NewListCommand(client azuremodels.Client) *cobra.Command { cmd := &cobra.Command{ Use: "list", Short: "List available models", diff --git a/cmd/run/run.go b/cmd/run/run.go index 446444b5..ee7d855e 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -189,7 +189,7 @@ func isPipe(r io.Reader) bool { } // NewRunCommand returns a new gh command for running a model. -func NewRunCommand(client *azuremodels.AzureClient) *cobra.Command { +func NewRunCommand(client azuremodels.Client) *cobra.Command { cmd := &cobra.Command{ Use: "run [model] [prompt]", Short: "Run inference with the specified model", @@ -369,11 +369,11 @@ type runCommandHandler struct { terminal term.Term out io.Writer errOut io.Writer - client *azuremodels.AzureClient + client azuremodels.Client args []string } -func newRunCommandHandler(cmd *cobra.Command, client *azuremodels.AzureClient, args []string) *runCommandHandler { +func newRunCommandHandler(cmd *cobra.Command, client azuremodels.Client, args []string) *runCommandHandler { terminal := term.FromEnv() out := terminal.Out() return &runCommandHandler{ diff --git a/cmd/view/view.go b/cmd/view/view.go index 4d579e9d..3d27f3a8 100644 --- a/cmd/view/view.go +++ b/cmd/view/view.go @@ -12,7 +12,7 @@ import ( ) // NewViewCommand returns a new command to view details about a model. -func NewViewCommand(client *azuremodels.AzureClient) *cobra.Command { +func NewViewCommand(client azuremodels.Client) *cobra.Command { cmd := &cobra.Command{ Use: "view [model]", Short: "View details about a model", diff --git a/internal/azuremodels/client.go b/internal/azuremodels/client.go new file mode 100644 index 00000000..37a981b0 --- /dev/null +++ b/internal/azuremodels/client.go @@ -0,0 +1,9 @@ +package azuremodels + +import "context" + +type Client interface { + GetChatCompletionStream(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) + GetModelDetails(context.Context, string, string, string) (*ModelDetails, error) + ListModels(context.Context) ([]*ModelSummary, error) +} From ef468eb2d0286fc270498d58ccbc96ba24cb0d57 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:06:51 -0500 Subject: [PATCH 04/28] Add command.Config type --- cmd/list/list.go | 15 +++----- cmd/root.go | 14 ++++--- cmd/run/run.go | 80 +++++++++++++++++---------------------- cmd/view/model_printer.go | 9 ++--- cmd/view/view.go | 9 ++--- pkg/command/config.go | 22 +++++++++++ 6 files changed, 79 insertions(+), 70 deletions(-) create mode 100644 pkg/command/config.go diff --git a/cmd/list/list.go b/cmd/list/list.go index b1d12224..4d2e1e9c 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -5,9 +5,9 @@ import ( "fmt" "github.com/cli/go-gh/v2/pkg/tableprinter" - "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/ux" + "github.com/github/gh-models/pkg/command" "github.com/github/gh-models/pkg/util" "github.com/mgutz/ansi" "github.com/spf13/cobra" @@ -18,14 +18,14 @@ var ( ) // NewListCommand returns a new command to list available GitHub models. -func NewListCommand(client azuremodels.Client) *cobra.Command { +func NewListCommand(cfg *command.Config) *cobra.Command { cmd := &cobra.Command{ Use: "list", Short: "List available models", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - + client := cfg.Client models, err := client.ListModels(ctx) if err != nil { return err @@ -36,18 +36,15 @@ func NewListCommand(client azuremodels.Client) *cobra.Command { models = filterToChatModels(models) ux.SortModels(models) - terminal := term.FromEnv() - out := terminal.Out() - isTTY := terminal.IsTerminalOutput() + out := cfg.Out - if isTTY { + if cfg.IsTerminalOutput { util.WriteToOut(out, "\n") util.WriteToOut(out, fmt.Sprintf("Showing %d available chat models\n", len(models))) util.WriteToOut(out, "\n") } - width, _, _ := terminal.Size() - printer := tableprinter.New(out, isTTY, width) + printer := tableprinter.New(out, cfg.IsTerminalOutput, cfg.TerminalWidth) printer.AddHeader([]string{"DISPLAY NAME", "MODEL NAME"}, tableprinter.WithColor(lightGrayUnderline)) printer.EndRow() diff --git a/cmd/root.go b/cmd/root.go index 451774c7..c8387dae 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -10,6 +10,7 @@ import ( "github.com/github/gh-models/cmd/run" "github.com/github/gh-models/cmd/view" "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" ) @@ -21,18 +22,21 @@ func NewRootCommand() *cobra.Command { Short: "GitHub Models extension", } + terminal := term.FromEnv() + out := terminal.Out() token, _ := auth.TokenForHost("github.com") if token == "" { - terminal := term.FromEnv() - util.WriteToOut(terminal.Out(), "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") return nil } client := azuremodels.NewAzureClient(token) + width, _, _ := terminal.Size() + cfg := command.NewConfig(out, terminal.ErrOut(), client, terminal.IsTerminalOutput(), width) - cmd.AddCommand(list.NewListCommand(client)) - cmd.AddCommand(run.NewRunCommand(client)) - cmd.AddCommand(view.NewViewCommand(client)) + cmd.AddCommand(list.NewListCommand(cfg)) + cmd.AddCommand(run.NewRunCommand(cfg)) + cmd.AddCommand(view.NewViewCommand(cfg)) // Cobra doesn't have a way to specify a two word command (ie. "gh models"), so set a custom usage template // with `gh`` in it. Cobra will use this template for the root and all child commands. diff --git a/cmd/run/run.go b/cmd/run/run.go index ee7d855e..e609160d 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -14,10 +14,10 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/briandowns/spinner" - "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/sse" "github.com/github/gh-models/internal/ux" + "github.com/github/gh-models/pkg/command" "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -189,13 +189,13 @@ func isPipe(r io.Reader) bool { } // NewRunCommand returns a new gh command for running a model. -func NewRunCommand(client azuremodels.Client) *cobra.Command { +func NewRunCommand(cfg *command.Config) *cobra.Command { cmd := &cobra.Command{ Use: "run [model] [prompt]", Short: "Run inference with the specified model", Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { - cmdHandler := newRunCommandHandler(cmd, client, args) + cmdHandler := newRunCommandHandler(cmd, cfg, args) if cmdHandler == nil { return nil } @@ -306,7 +306,7 @@ func NewRunCommand(client azuremodels.Client) *cobra.Command { mp.UpdateRequest(&req) - sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(cmdHandler.errOut)) + sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(cmdHandler.cfg.ErrOut)) sp.Start() //nolint:gocritic,revive // TODO defer sp.Stop() @@ -339,7 +339,7 @@ func NewRunCommand(client azuremodels.Client) *cobra.Command { } } - util.WriteToOut(cmdHandler.out, "\n") + util.WriteToOut(cmdHandler.cfg.Out, "\n") _, err = messageBuilder.WriteString("\n") if err != nil { return err @@ -365,29 +365,17 @@ func NewRunCommand(client azuremodels.Client) *cobra.Command { } type runCommandHandler struct { - ctx context.Context - terminal term.Term - out io.Writer - errOut io.Writer - client azuremodels.Client - args []string + ctx context.Context + cfg *command.Config + args []string } -func newRunCommandHandler(cmd *cobra.Command, client azuremodels.Client, args []string) *runCommandHandler { - terminal := term.FromEnv() - out := terminal.Out() - return &runCommandHandler{ - ctx: cmd.Context(), - terminal: terminal, - out: out, - args: args, - errOut: terminal.ErrOut(), - client: client, - } +func newRunCommandHandler(cmd *cobra.Command, cfg *command.Config, args []string) *runCommandHandler { + return &runCommandHandler{ctx: cmd.Context(), cfg: cfg, args: args} } func (h *runCommandHandler) loadModels() ([]*azuremodels.ModelSummary, error) { - models, err := h.client.ListModels(h.ctx) + models, err := h.cfg.Client.ListModels(h.ctx) if err != nil { return nil, err } @@ -450,7 +438,7 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st } func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions) (sse.Reader[azuremodels.ChatCompletion], error) { - resp, err := h.client.GetChatCompletionStream(h.ctx, req) + resp, err := h.cfg.Client.GetChatCompletionStream(h.ctx, req) if err != nil { return nil, err } @@ -458,23 +446,23 @@ func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCo } func (h *runCommandHandler) handleParametersPrompt(conversation Conversation, mp ModelParameters) { - util.WriteToOut(h.out, "Current parameters:\n") + util.WriteToOut(h.cfg.Out, "Current parameters:\n") names := []string{"max-tokens", "temperature", "top-p"} for _, name := range names { - util.WriteToOut(h.out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) + util.WriteToOut(h.cfg.Out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) } - util.WriteToOut(h.out, "\n") - util.WriteToOut(h.out, "System Prompt:\n") + util.WriteToOut(h.cfg.Out, "\n") + util.WriteToOut(h.cfg.Out, "System Prompt:\n") if conversation.systemPrompt != "" { - util.WriteToOut(h.out, " "+conversation.systemPrompt+"\n") + util.WriteToOut(h.cfg.Out, " "+conversation.systemPrompt+"\n") } else { - util.WriteToOut(h.out, " \n") + util.WriteToOut(h.cfg.Out, " \n") } } func (h *runCommandHandler) handleResetPrompt(conversation Conversation) { conversation.Reset() - util.WriteToOut(h.out, "Reset chat history\n") + util.WriteToOut(h.cfg.Out, "Reset chat history\n") } func (h *runCommandHandler) handleSetPrompt(prompt string, mp ModelParameters) { @@ -485,34 +473,34 @@ func (h *runCommandHandler) handleSetPrompt(prompt string, mp ModelParameters) { err := mp.SetParameterByName(name, value) if err != nil { - util.WriteToOut(h.out, err.Error()+"\n") + util.WriteToOut(h.cfg.Out, err.Error()+"\n") return } - util.WriteToOut(h.out, "Set "+name+" to "+value+"\n") + util.WriteToOut(h.cfg.Out, "Set "+name+" to "+value+"\n") } else { - util.WriteToOut(h.out, "Invalid /set syntax. Usage: /set \n") + util.WriteToOut(h.cfg.Out, "Invalid /set syntax. Usage: /set \n") } } func (h *runCommandHandler) handleSystemPrompt(prompt string, conversation Conversation) Conversation { conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") - util.WriteToOut(h.out, "Updated system prompt\n") + util.WriteToOut(h.cfg.Out, "Updated system prompt\n") return conversation } func (h *runCommandHandler) handleHelpPrompt() { - util.WriteToOut(h.out, "Commands:\n") - util.WriteToOut(h.out, " /bye, /exit, /quit - Exit the chat\n") - util.WriteToOut(h.out, " /parameters - Show current model parameters\n") - util.WriteToOut(h.out, " /reset, /clear - Reset chat context\n") - util.WriteToOut(h.out, " /set - Set a model parameter\n") - util.WriteToOut(h.out, " /system-prompt - Set the system prompt\n") - util.WriteToOut(h.out, " /help - Show this help message\n") + util.WriteToOut(h.cfg.Out, "Commands:\n") + util.WriteToOut(h.cfg.Out, " /bye, /exit, /quit - Exit the chat\n") + util.WriteToOut(h.cfg.Out, " /parameters - Show current model parameters\n") + util.WriteToOut(h.cfg.Out, " /reset, /clear - Reset chat context\n") + util.WriteToOut(h.cfg.Out, " /set - Set a model parameter\n") + util.WriteToOut(h.cfg.Out, " /system-prompt - Set the system prompt\n") + util.WriteToOut(h.cfg.Out, " /help - Show this help message\n") } func (h *runCommandHandler) handleUnrecognizedPrompt(prompt string) { - util.WriteToOut(h.out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") + util.WriteToOut(h.cfg.Out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") } func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice, messageBuilder strings.Builder) error { @@ -524,18 +512,18 @@ func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice if err != nil { return err } - util.WriteToOut(h.out, *content) + util.WriteToOut(h.cfg.Out, *content) } else if choice.Message != nil && choice.Message.Content != nil { content := choice.Message.Content _, err := messageBuilder.WriteString(*content) if err != nil { return err } - util.WriteToOut(h.out, *content) + util.WriteToOut(h.cfg.Out, *content) } // Introduce a small delay in between response tokens to better simulate a conversation - if h.terminal.IsTerminalOutput() { + if h.cfg.IsTerminalOutput { time.Sleep(10 * time.Millisecond) } diff --git a/cmd/view/model_printer.go b/cmd/view/model_printer.go index 63790f3c..28c885bc 100644 --- a/cmd/view/model_printer.go +++ b/cmd/view/model_printer.go @@ -5,8 +5,8 @@ import ( "github.com/cli/cli/v2/pkg/markdown" "github.com/cli/go-gh/v2/pkg/tableprinter" - "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" "github.com/mgutz/ansi" ) @@ -21,10 +21,9 @@ type modelPrinter struct { terminalWidth int } -func newModelPrinter(summary *azuremodels.ModelSummary, details *azuremodels.ModelDetails, terminal term.Term) modelPrinter { - width, _, _ := terminal.Size() - printer := tableprinter.New(terminal.Out(), terminal.IsTerminalOutput(), width) - return modelPrinter{modelSummary: summary, modelDetails: details, printer: printer, terminalWidth: width} +func newModelPrinter(summary *azuremodels.ModelSummary, details *azuremodels.ModelDetails, cfg *command.Config) modelPrinter { + printer := tableprinter.New(cfg.Out, cfg.IsTerminalOutput, cfg.TerminalWidth) + return modelPrinter{modelSummary: summary, modelDetails: details, printer: printer, terminalWidth: cfg.TerminalWidth} } func (p *modelPrinter) render() error { diff --git a/cmd/view/view.go b/cmd/view/view.go index 3d27f3a8..37e34e03 100644 --- a/cmd/view/view.go +++ b/cmd/view/view.go @@ -5,21 +5,21 @@ import ( "fmt" "github.com/AlecAivazis/survey/v2" - "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/ux" + "github.com/github/gh-models/pkg/command" "github.com/spf13/cobra" ) // NewViewCommand returns a new command to view details about a model. -func NewViewCommand(client azuremodels.Client) *cobra.Command { +func NewViewCommand(cfg *command.Config) *cobra.Command { cmd := &cobra.Command{ Use: "view [model]", Short: "View details about a model", Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - + client := cfg.Client models, err := client.ListModels(ctx) if err != nil { return err @@ -62,8 +62,7 @@ func NewViewCommand(client azuremodels.Client) *cobra.Command { return err } - terminal := term.FromEnv() - modelPrinter := newModelPrinter(modelSummary, modelDetails, terminal) + modelPrinter := newModelPrinter(modelSummary, modelDetails, cfg) err = modelPrinter.render() if err != nil { diff --git a/pkg/command/config.go b/pkg/command/config.go new file mode 100644 index 00000000..eeb20d74 --- /dev/null +++ b/pkg/command/config.go @@ -0,0 +1,22 @@ +// Package command provides shared configuration for sub-commands in the gh-models extension. +package command + +import ( + "io" + + "github.com/github/gh-models/internal/azuremodels" +) + +// Config represents configurable settings for a command. +type Config struct { + Out io.Writer + ErrOut io.Writer + Client azuremodels.Client + IsTerminalOutput bool + TerminalWidth int +} + +// NewConfig returns a new command configuration. +func NewConfig(out io.Writer, errOut io.Writer, client azuremodels.Client, isTerminalOutput bool, width int) *Config { + return &Config{Out: out, ErrOut: errOut, Client: client, IsTerminalOutput: isTerminalOutput, TerminalWidth: width} +} From ec0883b48058c20c5e5ab45dd3d023e8a3950e0c Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:07:40 -0500 Subject: [PATCH 05/28] Docs --- internal/azuremodels/azure_client.go | 4 ++-- internal/azuremodels/client.go | 4 ++++ pkg/command/config.go | 13 +++++++++---- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/internal/azuremodels/azure_client.go b/internal/azuremodels/azure_client.go index 82ce619c..d5f77e47 100644 --- a/internal/azuremodels/azure_client.go +++ b/internal/azuremodels/azure_client.go @@ -38,7 +38,7 @@ func NewAzureClient(authToken string) *AzureClient { } } -// GetChatCompletionStream returns a stream of chat completions for the given request. +// GetChatCompletionStream returns a stream of chat completions using the given options. func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions) (*ChatCompletionResponse, error) { // Check if the model name is `o1-mini` or `o1-preview` if req.Model == "o1-mini" || req.Model == "o1-preview" { @@ -92,7 +92,7 @@ func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompl return &chatCompletionResponse, nil } -// GetModelDetails returns the details of the specified model in a prticular registry. +// GetModelDetails returns the details of the specified model in a particular registry. func (c *AzureClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version) httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) diff --git a/internal/azuremodels/client.go b/internal/azuremodels/client.go index 37a981b0..4e1b83f9 100644 --- a/internal/azuremodels/client.go +++ b/internal/azuremodels/client.go @@ -2,8 +2,12 @@ package azuremodels import "context" +// Client represents a client for interacting with an API about models. type Client interface { + // GetChatCompletionStream returns a stream of chat completions using the given options. GetChatCompletionStream(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) + // GetModelDetails returns the details of the specified model in a particular registry. GetModelDetails(context.Context, string, string, string) (*ModelDetails, error) + // ListModels returns a list of available models. ListModels(context.Context) ([]*ModelSummary, error) } diff --git a/pkg/command/config.go b/pkg/command/config.go index eeb20d74..8ff505a4 100644 --- a/pkg/command/config.go +++ b/pkg/command/config.go @@ -9,11 +9,16 @@ import ( // Config represents configurable settings for a command. type Config struct { - Out io.Writer - ErrOut io.Writer - Client azuremodels.Client + // Out is where standard output is written. + Out io.Writer + // ErrOut is where error output is written. + ErrOut io.Writer + // Client is the client for interacting with the models service. + Client azuremodels.Client + // IsTerminalOutput is true if the output should be formatted for a terminal. IsTerminalOutput bool - TerminalWidth int + // TerminalWidth is the width of the terminal. + TerminalWidth int } // NewConfig returns a new command configuration. From 8076365cd0e1e51767d6afeb2db105dca9de25a8 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:12:51 -0500 Subject: [PATCH 06/28] Add MockClient for testing --- internal/azuremodels/mock_client.go | 43 +++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 internal/azuremodels/mock_client.go diff --git a/internal/azuremodels/mock_client.go b/internal/azuremodels/mock_client.go new file mode 100644 index 00000000..e0b1ba8e --- /dev/null +++ b/internal/azuremodels/mock_client.go @@ -0,0 +1,43 @@ +package azuremodels + +import ( + "context" +) + +// MockClient provides a client for interacting with the Azure models API in tests. +type MockClient struct { + MockGetChatCompletionStream func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) + MockGetModelDetails func(context.Context, string, string, string) (*ModelDetails, error) + MockListModels func(context.Context) ([]*ModelSummary, error) +} + +// NewMockClient returns a new mock client. +func NewMockClient() *MockClient { + return &MockClient{ + MockGetChatCompletionStream: func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) { + return nil, nil + }, + MockGetModelDetails: func(context.Context, string, string, string) (*ModelDetails, error) { + return nil, nil + }, + MockListModels: func(context.Context) ([]*ModelSummary, error) { + return nil, nil + }, + } +} + +// GetChatCompletionStream calls the mocked function for getting a stream of chat completions for the given request. + +func (c *MockClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) { + return c.MockGetChatCompletionStream(ctx, opt) +} + +// GetModelDetails calls the mocked function for getting the details of the specified model in a particular registry. +func (c *MockClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { + return c.MockGetModelDetails(ctx, registry, modelName, version) +} + +// ListModels calls the mocked function for getting a list of available models. +func (c *MockClient) ListModels(ctx context.Context) ([]*ModelSummary, error) { + return c.MockListModels(ctx) +} From 0e91a6ecb62c72d92960e2a736c4b8561731670f Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:13:10 -0500 Subject: [PATCH 07/28] Add a basic test for the list command --- cmd/list/list_test.go | 46 +++++++++++++++++++++++++++++++++++++++++++ go.mod | 3 +++ 2 files changed, 49 insertions(+) create mode 100644 cmd/list/list_test.go diff --git a/cmd/list/list_test.go b/cmd/list/list_test.go new file mode 100644 index 00000000..60ceabed --- /dev/null +++ b/cmd/list/list_test.go @@ -0,0 +1,46 @@ +package list + +import ( + "bytes" + "context" + "testing" + + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" + "github.com/stretchr/testify/require" +) + +func TestList(t *testing.T) { + t.Run("NewListCommand happy path", func(t *testing.T) { + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + ID: "test-id-1", + Name: "test-model-1", + FriendlyName: "Test Model 1", + Task: "chat-completion", + Publisher: "OpenAI", + Summary: "This is a test model", + Version: "1.0", + RegistryName: "azure-openai", + } + listModelsCallCount := 0 + client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) { + listModelsCallCount++ + return []*azuremodels.ModelSummary{modelSummary}, nil + } + buf := new(bytes.Buffer) + cfg := command.NewConfig(buf, buf, client, true, 80) + listCmd := NewListCommand(cfg) + + _, err := listCmd.ExecuteC() + + require.NoError(t, err) + require.Equal(t, 1, listModelsCallCount) + output := buf.String() + require.Contains(t, output, "Showing 1 available chat models") + require.Contains(t, output, "DISPLAY NAME") + require.Contains(t, output, "MODEL NAME") + require.Contains(t, output, modelSummary.FriendlyName) + require.Contains(t, output, modelSummary.Name) + }) +} diff --git a/go.mod b/go.mod index 27a21d5a..9b73ace2 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.9.0 golang.org/x/text v0.18.0 ) @@ -24,6 +25,7 @@ require ( github.com/charmbracelet/x/exp/term v0.0.0-20240425164147-ba2a9512b05f // indirect github.com/cli/safeexec v1.0.1 // indirect github.com/cli/shurcooL-graphql v0.0.4 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dlclark/regexp2 v1.4.0 // indirect github.com/fatih/color v1.16.0 // indirect github.com/gorilla/css v1.0.0 // indirect @@ -39,6 +41,7 @@ require ( github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.15.2 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e // indirect github.com/yuin/goldmark v1.5.4 // indirect From 3122947fdd208f31cbe6e7932144c44ab3de61c4 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:17:22 -0500 Subject: [PATCH 08/28] Probably don't need XL for lint --- .github/workflows/lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index cc5ebdce..b2a4c7fc 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -12,7 +12,7 @@ jobs: lint: strategy: fail-fast: false - runs-on: ubuntu-latest-xl + runs-on: ubuntu-latest env: GOPROXY: https://goproxy.githubapp.com/mod,https://proxy.golang.org/,direct GOPRIVATE: "" From 67818636efce5b25e8fa1ec22168ac906ed1cd66 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:17:49 -0500 Subject: [PATCH 09/28] Add test workflow https://docs.github.com/en/actions/use-cases-and-examples/building-and-testing/building-and-testing-go --- .github/workflows/test.yml | 42 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..e748bf90 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,42 @@ +name: "Build and test" + +on: + pull_request: + push: + workflow_dispatch: + merge_group: + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + env: + GOPROXY: https://goproxy.githubapp.com/mod,https://proxy.golang.org/,direct + GOPRIVATE: "" + GONOPROXY: "" + GONOSUMDB: github.com/github/* + steps: + - uses: actions/checkout@v4 + - name: Configure Go private module access + run: | + echo "machine goproxy.githubapp.com login nobody password ${{ secrets.GOPROXY_TOKEN }}" >> $HOME/.netrc + - uses: actions/setup-go@v5 + with: + go-version: ${{ vars.GOVERSION }} + check-latest: true + - name: Verify go.sum is up to date + run: | + go mod tidy + git diff --exit-code go.sum + if [ $? -ne 0 ]; then + echo "Error: go.sum has changed, please run `go mod tidy` and commit the result" + exit 1 + fi + + - name: Build program + run: go build -v ./... + + - name: Run tests + run: go test -race ./... From f247ca25450806dcae6d9404024e60adde8bc825 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:19:27 -0500 Subject: [PATCH 10/28] Cancel existing runs when starting a new one --- .github/workflows/lint.yml | 4 ++++ .github/workflows/test.yml | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b2a4c7fc..c0aa114c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -8,6 +8,10 @@ on: permissions: contents: read +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: lint: strategy: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e748bf90..93954bd8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,6 +9,10 @@ on: permissions: contents: read +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: build: runs-on: ubuntu-latest From f1d30e08951aca822e93ff184067217a0b0312b4 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:20:00 -0500 Subject: [PATCH 11/28] Don't need dual push + PR triggers --- .github/workflows/test.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 93954bd8..b4f8303f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,7 +2,6 @@ name: "Build and test" on: pull_request: - push: workflow_dispatch: merge_group: From 644fb0dc896524a346b7d80cfa59a7cdaf17405c Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:25:52 -0500 Subject: [PATCH 12/28] Linter fixes --- internal/azuremodels/azure_client.go | 2 +- internal/azuremodels/mock_client.go | 3 +-- pkg/command/config.go | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/internal/azuremodels/azure_client.go b/internal/azuremodels/azure_client.go index d5f77e47..7d58da70 100644 --- a/internal/azuremodels/azure_client.go +++ b/internal/azuremodels/azure_client.go @@ -29,7 +29,7 @@ const ( prodModelsURL = azureAiStudioURL + "/asset-gallery/v1.0/models" ) -// NewClient returns a new client using the given auth token. +// NewAzureClient returns a new Azure client using the given auth token. func NewAzureClient(authToken string) *AzureClient { httpClient, _ := api.DefaultHTTPClient() return &AzureClient{ diff --git a/internal/azuremodels/mock_client.go b/internal/azuremodels/mock_client.go index e0b1ba8e..24432b2e 100644 --- a/internal/azuremodels/mock_client.go +++ b/internal/azuremodels/mock_client.go @@ -11,7 +11,7 @@ type MockClient struct { MockListModels func(context.Context) ([]*ModelSummary, error) } -// NewMockClient returns a new mock client. +// NewMockClient returns a new mock client for stubbing out interactions with the models API. func NewMockClient() *MockClient { return &MockClient{ MockGetChatCompletionStream: func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) { @@ -27,7 +27,6 @@ func NewMockClient() *MockClient { } // GetChatCompletionStream calls the mocked function for getting a stream of chat completions for the given request. - func (c *MockClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) { return c.MockGetChatCompletionStream(ctx, opt) } diff --git a/pkg/command/config.go b/pkg/command/config.go index 8ff505a4..ba8b8a51 100644 --- a/pkg/command/config.go +++ b/pkg/command/config.go @@ -22,6 +22,6 @@ type Config struct { } // NewConfig returns a new command configuration. -func NewConfig(out io.Writer, errOut io.Writer, client azuremodels.Client, isTerminalOutput bool, width int) *Config { +func NewConfig(out, errOut io.Writer, client azuremodels.Client, isTerminalOutput bool, width int) *Config { return &Config{Out: out, ErrOut: errOut, Client: client, IsTerminalOutput: isTerminalOutput, TerminalWidth: width} } From d236e8f5e929ed4907bde14ca3a93a052101c5fd Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:27:54 -0500 Subject: [PATCH 13/28] Add NewConfigWithTerminal --- cmd/root.go | 3 +-- pkg/command/config.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index c8387dae..b25a47a5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -31,8 +31,7 @@ func NewRootCommand() *cobra.Command { } client := azuremodels.NewAzureClient(token) - width, _, _ := terminal.Size() - cfg := command.NewConfig(out, terminal.ErrOut(), client, terminal.IsTerminalOutput(), width) + cfg := command.NewConfigWithTerminal(terminal, client) cmd.AddCommand(list.NewListCommand(cfg)) cmd.AddCommand(run.NewRunCommand(cfg)) diff --git a/pkg/command/config.go b/pkg/command/config.go index ba8b8a51..00cbe61f 100644 --- a/pkg/command/config.go +++ b/pkg/command/config.go @@ -4,6 +4,7 @@ package command import ( "io" + "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" ) @@ -25,3 +26,15 @@ type Config struct { func NewConfig(out, errOut io.Writer, client azuremodels.Client, isTerminalOutput bool, width int) *Config { return &Config{Out: out, ErrOut: errOut, Client: client, IsTerminalOutput: isTerminalOutput, TerminalWidth: width} } + +// NewConfigWithTerminal returns a new command configuration using the given terminal. +func NewConfigWithTerminal(terminal term.Term, client azuremodels.Client) *Config { + width, _, _ := terminal.Size() + return &Config{ + Out: terminal.Out(), + ErrOut: terminal.ErrOut(), + Client: client, + IsTerminalOutput: terminal.IsTerminalOutput(), + TerminalWidth: width, + } +} From 9a307d64b4d3ecfdefe3b75ad90968c27560cf4c Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:33:35 -0500 Subject: [PATCH 14/28] Fix label for model name in view output --- cmd/view/model_printer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/view/model_printer.go b/cmd/view/model_printer.go index 28c885bc..22451952 100644 --- a/cmd/view/model_printer.go +++ b/cmd/view/model_printer.go @@ -30,7 +30,7 @@ func (p *modelPrinter) render() error { modelSummary := p.modelSummary if modelSummary != nil { p.printLabelledLine("Display name:", modelSummary.FriendlyName) - p.printLabelledLine("Summary name:", modelSummary.Name) + p.printLabelledLine("Model name:", modelSummary.Name) p.printLabelledLine("Publisher:", modelSummary.Publisher) p.printLabelledLine("Summary:", modelSummary.Summary) } From 5d95d27f7749b4bc7a0fb93abe5e8dc8b6014253 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:36:50 -0500 Subject: [PATCH 15/28] Add basic test for view command --- cmd/view/view_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 cmd/view/view_test.go diff --git a/cmd/view/view_test.go b/cmd/view/view_test.go new file mode 100644 index 00000000..8348b31e --- /dev/null +++ b/cmd/view/view_test.go @@ -0,0 +1,89 @@ +package view + +import ( + "bytes" + "context" + "testing" + + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" + "github.com/stretchr/testify/require" +) + +func TestView(t *testing.T) { + t.Run("NewViewCommand happy path", func(t *testing.T) { + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + ID: "test-id-1", + Name: "test-model-1", + FriendlyName: "Test Model 1", + Task: "chat-completion", + Publisher: "OpenAI", + Summary: "This is a test model", + Version: "1.0", + RegistryName: "azure-openai", + } + listModelsCallCount := 0 + client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) { + listModelsCallCount++ + return []*azuremodels.ModelSummary{modelSummary}, nil + } + getModelDetailsCallCount := 0 + modelDetails := &azuremodels.ModelDetails{ + Description: "Fake description", + Evaluation: "Fake evaluation", + License: "MIT", + LicenseDescription: "This is a test license", + Tags: []string{"tag1", "tag2"}, + SupportedInputModalities: []string{"text", "carrier-pigeon"}, + SupportedOutputModalities: []string{"underwater-signals"}, + SupportedLanguages: []string{"English", "Spanish"}, + MaxOutputTokens: 123, + MaxInputTokens: 456, + RateLimitTier: "mediumish", + } + client.MockGetModelDetails = func(ctx context.Context, registryName, modelName, version string) (*azuremodels.ModelDetails, error) { + getModelDetailsCallCount++ + return modelDetails, nil + } + buf := new(bytes.Buffer) + cfg := command.NewConfig(buf, buf, client, true, 80) + viewCmd := NewViewCommand(cfg) + viewCmd.SetArgs([]string{modelSummary.Name}) + + _, err := viewCmd.ExecuteC() + + require.NoError(t, err) + require.Equal(t, 1, listModelsCallCount) + require.Equal(t, 1, getModelDetailsCallCount) + output := buf.String() + require.Contains(t, output, "Display name:") + require.Contains(t, output, modelSummary.FriendlyName) + require.Contains(t, output, "Model name:") + require.Contains(t, output, modelSummary.Name) + require.Contains(t, output, "Publisher:") + require.Contains(t, output, modelSummary.Publisher) + require.Contains(t, output, "Summary:") + require.Contains(t, output, modelSummary.Summary) + require.Contains(t, output, "Context:") + require.Contains(t, output, "up to 456 input tokens and 123 output tokens") + require.Contains(t, output, "Rate limit tier:") + require.Contains(t, output, "mediumish") + require.Contains(t, output, "Tags:") + require.Contains(t, output, "tag1, tag2") + require.Contains(t, output, "Supported input types:") + require.Contains(t, output, "text, carrier-pigeon") + require.Contains(t, output, "Supported output types:") + require.Contains(t, output, "underwater-signals") + require.Contains(t, output, "Supported languages:") + require.Contains(t, output, "English, Spanish") + require.Contains(t, output, "License:") + require.Contains(t, output, modelDetails.License) + require.Contains(t, output, "License description:") + require.Contains(t, output, modelDetails.LicenseDescription) + require.Contains(t, output, "Description:") + require.Contains(t, output, modelDetails.Description) + require.Contains(t, output, "Evaluation:") + require.Contains(t, output, modelDetails.Evaluation) + }) +} From 13dc7956dcbed52d934fa8fda7285d9893554bf1 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 12:40:18 -0500 Subject: [PATCH 16/28] Return 'not implemented' error by default for mocks --- internal/azuremodels/mock_client.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/internal/azuremodels/mock_client.go b/internal/azuremodels/mock_client.go index 24432b2e..c15cfb6d 100644 --- a/internal/azuremodels/mock_client.go +++ b/internal/azuremodels/mock_client.go @@ -2,6 +2,7 @@ package azuremodels import ( "context" + "errors" ) // MockClient provides a client for interacting with the Azure models API in tests. @@ -15,13 +16,13 @@ type MockClient struct { func NewMockClient() *MockClient { return &MockClient{ MockGetChatCompletionStream: func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) { - return nil, nil + return nil, errors.New("GetChatCompletionStream not implemented") }, MockGetModelDetails: func(context.Context, string, string, string) (*ModelDetails, error) { - return nil, nil + return nil, errors.New("GetModelDetails not implemented") }, MockListModels: func(context.Context) ([]*ModelSummary, error) { - return nil, nil + return nil, errors.New("ListModels not implemented") }, } } From 57b6183a9d91d3eb42ccb6d47e2faaf70e8b1b26 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 13:30:27 -0500 Subject: [PATCH 17/28] Export ChatChoiceMessage type from azuremodels --- internal/azuremodels/types.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/azuremodels/types.go b/internal/azuremodels/types.go index 98138faf..d8d5a52d 100644 --- a/internal/azuremodels/types.go +++ b/internal/azuremodels/types.go @@ -36,7 +36,8 @@ type ChatCompletionOptions struct { TopP *float64 `json:"top_p,omitempty"` } -type chatChoiceMessage struct { +// ChatChoiceMessage is a message from a choice in a chat conversation. +type ChatChoiceMessage struct { Content *string `json:"content,omitempty"` Role *string `json:"role,omitempty"` } @@ -51,7 +52,7 @@ type ChatChoice struct { Delta *chatChoiceDelta `json:"delta,omitempty"` FinishReason string `json:"finish_reason"` Index int32 `json:"index"` - Message *chatChoiceMessage `json:"message,omitempty"` + Message *ChatChoiceMessage `json:"message,omitempty"` } // ChatCompletion represents a chat completion. From a804f546eec41ada90116fc6b3e23c6c86653ae7 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 13:30:33 -0500 Subject: [PATCH 18/28] Add basic run command test --- cmd/run/run_test.go | 62 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 cmd/run/run_test.go diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go new file mode 100644 index 00000000..92755745 --- /dev/null +++ b/cmd/run/run_test.go @@ -0,0 +1,62 @@ +package run + +import ( + "bytes" + "context" + "testing" + + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/internal/sse" + "github.com/github/gh-models/pkg/command" + "github.com/github/gh-models/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestRun(t *testing.T) { + t.Run("NewRunCommand happy path", func(t *testing.T) { + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + ID: "test-id-1", + Name: "test-model-1", + FriendlyName: "Test Model 1", + Task: "chat-completion", + Publisher: "OpenAI", + Summary: "This is a test model", + Version: "1.0", + RegistryName: "azure-openai", + } + listModelsCallCount := 0 + client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) { + listModelsCallCount++ + return []*azuremodels.ModelSummary{modelSummary}, nil + } + fakeMessageFromModel := "yes hello this is dog" + chatChoice := azuremodels.ChatChoice{ + Message: &azuremodels.ChatChoiceMessage{ + Content: util.Ptr(fakeMessageFromModel), + Role: util.Ptr(string(azuremodels.ChatMessageRoleAssistant)), + }, + } + chatCompletion := azuremodels.ChatCompletion{Choices: []azuremodels.ChatChoice{chatChoice}} + chatResp := &azuremodels.ChatCompletionResponse{ + Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}), + } + getChatCompletionCallCount := 0 + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + getChatCompletionCallCount++ + return chatResp, nil + } + buf := new(bytes.Buffer) + cfg := command.NewConfig(buf, buf, client, true, 80) + runCmd := NewRunCommand(cfg) + runCmd.SetArgs([]string{modelSummary.Name, "this is my prompt"}) + + _, err := runCmd.ExecuteC() + + require.NoError(t, err) + require.Equal(t, 1, listModelsCallCount) + require.Equal(t, 1, getChatCompletionCallCount) + output := buf.String() + require.Contains(t, output, fakeMessageFromModel) + }) +} From bd5911372a5455f1aea45c4e6b4ed0333490bcac Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 13:33:56 -0500 Subject: [PATCH 19/28] Don't need to rerun CI if a label changes, for example synchronize = pushed to head, see https://docs.github.com/webhooks/webhook-events-and-payloads?actionType=synchronize#pull_request --- .github/workflows/lint.yml | 1 + .github/workflows/test.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c0aa114c..8becfc2e 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -2,6 +2,7 @@ name: "go-linter" on: pull_request: + types: [opened, synchronize, reopened] merge_group: workflow_dispatch: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b4f8303f..1df3b021 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,6 +2,7 @@ name: "Build and test" on: pull_request: + types: [opened, synchronize, reopened] workflow_dispatch: merge_group: From f93fc5e66934a4ff94b2475ddae04ec4aeba376a Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 13:53:08 -0500 Subject: [PATCH 20/28] Run CI when pushing to main branch --- .github/workflows/lint.yml | 3 +++ .github/workflows/test.yml | 3 +++ 2 files changed, 6 insertions(+) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 8becfc2e..1659a21c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -5,6 +5,9 @@ on: types: [opened, synchronize, reopened] merge_group: workflow_dispatch: + push: + branches: + - 'main' permissions: contents: read diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1df3b021..08bd8ea5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -5,6 +5,9 @@ on: types: [opened, synchronize, reopened] workflow_dispatch: merge_group: + push: + branches: + - 'main' permissions: contents: read From 64bd851236e5f19c09419ce9345f5f8ba54831d5 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 14:00:56 -0500 Subject: [PATCH 21/28] Output code coverage percentages --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 08bd8ea5..49bbaed8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -46,4 +46,4 @@ jobs: run: go build -v ./... - name: Run tests - run: go test -race ./... + run: go test -race -cover ./... From d69e1ae63935de1d3290d705f27aeb34ea11bf2c Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 14:08:39 -0500 Subject: [PATCH 22/28] Add NewTablePrinter --- cmd/list/list.go | 10 ++++------ cmd/view/model_printer.go | 8 ++++++-- pkg/command/config.go | 6 ++++++ 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/cmd/list/list.go b/cmd/list/list.go index 4d2e1e9c..84bd4b1e 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -36,15 +36,13 @@ func NewListCommand(cfg *command.Config) *cobra.Command { models = filterToChatModels(models) ux.SortModels(models) - out := cfg.Out - if cfg.IsTerminalOutput { - util.WriteToOut(out, "\n") - util.WriteToOut(out, fmt.Sprintf("Showing %d available chat models\n", len(models))) - util.WriteToOut(out, "\n") + util.WriteToOut(cfg.Out, "\n") + util.WriteToOut(cfg.Out, fmt.Sprintf("Showing %d available chat models\n", len(models))) + util.WriteToOut(cfg.Out, "\n") } - printer := tableprinter.New(out, cfg.IsTerminalOutput, cfg.TerminalWidth) + printer := cfg.NewTablePrinter() printer.AddHeader([]string{"DISPLAY NAME", "MODEL NAME"}, tableprinter.WithColor(lightGrayUnderline)) printer.EndRow() diff --git a/cmd/view/model_printer.go b/cmd/view/model_printer.go index 22451952..6776c571 100644 --- a/cmd/view/model_printer.go +++ b/cmd/view/model_printer.go @@ -22,8 +22,12 @@ type modelPrinter struct { } func newModelPrinter(summary *azuremodels.ModelSummary, details *azuremodels.ModelDetails, cfg *command.Config) modelPrinter { - printer := tableprinter.New(cfg.Out, cfg.IsTerminalOutput, cfg.TerminalWidth) - return modelPrinter{modelSummary: summary, modelDetails: details, printer: printer, terminalWidth: cfg.TerminalWidth} + return modelPrinter{ + modelSummary: summary, + modelDetails: details, + printer: cfg.NewTablePrinter(), + terminalWidth: cfg.TerminalWidth, + } } func (p *modelPrinter) render() error { diff --git a/pkg/command/config.go b/pkg/command/config.go index 00cbe61f..6993178b 100644 --- a/pkg/command/config.go +++ b/pkg/command/config.go @@ -4,6 +4,7 @@ package command import ( "io" + "github.com/cli/go-gh/v2/pkg/tableprinter" "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" ) @@ -38,3 +39,8 @@ func NewConfigWithTerminal(terminal term.Term, client azuremodels.Client) *Confi TerminalWidth: width, } } + +// NewTablePrinter initializes a table printer with terminal mode and terminal width. +func (c *Config) NewTablePrinter() tableprinter.TablePrinter { + return tableprinter.New(c.Out, c.IsTerminalOutput, c.TerminalWidth) +} From 3738e1f27672c4886c8a466a262e097756c7587e Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 14:11:26 -0500 Subject: [PATCH 23/28] Add Config WriteToOut --- cmd/list/list.go | 7 +++---- cmd/run/run.go | 44 +++++++++++++++++++++---------------------- pkg/command/config.go | 6 ++++++ 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/cmd/list/list.go b/cmd/list/list.go index 84bd4b1e..75437939 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -8,7 +8,6 @@ import ( "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/ux" "github.com/github/gh-models/pkg/command" - "github.com/github/gh-models/pkg/util" "github.com/mgutz/ansi" "github.com/spf13/cobra" ) @@ -37,9 +36,9 @@ func NewListCommand(cfg *command.Config) *cobra.Command { ux.SortModels(models) if cfg.IsTerminalOutput { - util.WriteToOut(cfg.Out, "\n") - util.WriteToOut(cfg.Out, fmt.Sprintf("Showing %d available chat models\n", len(models))) - util.WriteToOut(cfg.Out, "\n") + cfg.WriteToOut("\n") + cfg.WriteToOut(fmt.Sprintf("Showing %d available chat models\n", len(models))) + cfg.WriteToOut("\n") } printer := cfg.NewTablePrinter() diff --git a/cmd/run/run.go b/cmd/run/run.go index e609160d..d08ddbfe 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -339,7 +339,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { } } - util.WriteToOut(cmdHandler.cfg.Out, "\n") + cmdHandler.cfg.WriteToOut("\n") _, err = messageBuilder.WriteString("\n") if err != nil { return err @@ -446,23 +446,23 @@ func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCo } func (h *runCommandHandler) handleParametersPrompt(conversation Conversation, mp ModelParameters) { - util.WriteToOut(h.cfg.Out, "Current parameters:\n") + h.cfg.WriteToOut("Current parameters:\n") names := []string{"max-tokens", "temperature", "top-p"} for _, name := range names { - util.WriteToOut(h.cfg.Out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) + h.cfg.WriteToOut(fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) } - util.WriteToOut(h.cfg.Out, "\n") - util.WriteToOut(h.cfg.Out, "System Prompt:\n") + h.cfg.WriteToOut("\n") + h.cfg.WriteToOut("System Prompt:\n") if conversation.systemPrompt != "" { - util.WriteToOut(h.cfg.Out, " "+conversation.systemPrompt+"\n") + h.cfg.WriteToOut(" " + conversation.systemPrompt + "\n") } else { - util.WriteToOut(h.cfg.Out, " \n") + h.cfg.WriteToOut(" \n") } } func (h *runCommandHandler) handleResetPrompt(conversation Conversation) { conversation.Reset() - util.WriteToOut(h.cfg.Out, "Reset chat history\n") + h.cfg.WriteToOut("Reset chat history\n") } func (h *runCommandHandler) handleSetPrompt(prompt string, mp ModelParameters) { @@ -473,34 +473,34 @@ func (h *runCommandHandler) handleSetPrompt(prompt string, mp ModelParameters) { err := mp.SetParameterByName(name, value) if err != nil { - util.WriteToOut(h.cfg.Out, err.Error()+"\n") + h.cfg.WriteToOut(err.Error() + "\n") return } - util.WriteToOut(h.cfg.Out, "Set "+name+" to "+value+"\n") + h.cfg.WriteToOut("Set " + name + " to " + value + "\n") } else { - util.WriteToOut(h.cfg.Out, "Invalid /set syntax. Usage: /set \n") + h.cfg.WriteToOut("Invalid /set syntax. Usage: /set \n") } } func (h *runCommandHandler) handleSystemPrompt(prompt string, conversation Conversation) Conversation { conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") - util.WriteToOut(h.cfg.Out, "Updated system prompt\n") + h.cfg.WriteToOut("Updated system prompt\n") return conversation } func (h *runCommandHandler) handleHelpPrompt() { - util.WriteToOut(h.cfg.Out, "Commands:\n") - util.WriteToOut(h.cfg.Out, " /bye, /exit, /quit - Exit the chat\n") - util.WriteToOut(h.cfg.Out, " /parameters - Show current model parameters\n") - util.WriteToOut(h.cfg.Out, " /reset, /clear - Reset chat context\n") - util.WriteToOut(h.cfg.Out, " /set - Set a model parameter\n") - util.WriteToOut(h.cfg.Out, " /system-prompt - Set the system prompt\n") - util.WriteToOut(h.cfg.Out, " /help - Show this help message\n") + h.cfg.WriteToOut("Commands:\n") + h.cfg.WriteToOut(" /bye, /exit, /quit - Exit the chat\n") + h.cfg.WriteToOut(" /parameters - Show current model parameters\n") + h.cfg.WriteToOut(" /reset, /clear - Reset chat context\n") + h.cfg.WriteToOut(" /set - Set a model parameter\n") + h.cfg.WriteToOut(" /system-prompt - Set the system prompt\n") + h.cfg.WriteToOut(" /help - Show this help message\n") } func (h *runCommandHandler) handleUnrecognizedPrompt(prompt string) { - util.WriteToOut(h.cfg.Out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") + h.cfg.WriteToOut("Unknown command '" + prompt + "'. See /help for supported commands.\n") } func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice, messageBuilder strings.Builder) error { @@ -512,14 +512,14 @@ func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice if err != nil { return err } - util.WriteToOut(h.cfg.Out, *content) + h.cfg.WriteToOut(*content) } else if choice.Message != nil && choice.Message.Content != nil { content := choice.Message.Content _, err := messageBuilder.WriteString(*content) if err != nil { return err } - util.WriteToOut(h.cfg.Out, *content) + h.cfg.WriteToOut(*content) } // Introduce a small delay in between response tokens to better simulate a conversation diff --git a/pkg/command/config.go b/pkg/command/config.go index 6993178b..36296b44 100644 --- a/pkg/command/config.go +++ b/pkg/command/config.go @@ -7,6 +7,7 @@ import ( "github.com/cli/go-gh/v2/pkg/tableprinter" "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/util" ) // Config represents configurable settings for a command. @@ -44,3 +45,8 @@ func NewConfigWithTerminal(terminal term.Term, client azuremodels.Client) *Confi func (c *Config) NewTablePrinter() tableprinter.TablePrinter { return tableprinter.New(c.Out, c.IsTerminalOutput, c.TerminalWidth) } + +// WriteToOut writes a message to the configured stdout writer. +func (c *Config) WriteToOut(message string) { + util.WriteToOut(c.Out, message) +} From 099849adb90d5cb5c021a4b0a763072423699396 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 14:12:48 -0500 Subject: [PATCH 24/28] Add runCommandHandler writeToOut Shorter lines, a little less reaching through. --- cmd/run/run.go | 48 ++++++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/cmd/run/run.go b/cmd/run/run.go index d08ddbfe..120441c0 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -339,7 +339,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { } } - cmdHandler.cfg.WriteToOut("\n") + cmdHandler.writeToOut("\n") _, err = messageBuilder.WriteString("\n") if err != nil { return err @@ -446,23 +446,23 @@ func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCo } func (h *runCommandHandler) handleParametersPrompt(conversation Conversation, mp ModelParameters) { - h.cfg.WriteToOut("Current parameters:\n") + h.writeToOut("Current parameters:\n") names := []string{"max-tokens", "temperature", "top-p"} for _, name := range names { - h.cfg.WriteToOut(fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) + h.writeToOut(fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) } - h.cfg.WriteToOut("\n") - h.cfg.WriteToOut("System Prompt:\n") + h.writeToOut("\n") + h.writeToOut("System Prompt:\n") if conversation.systemPrompt != "" { - h.cfg.WriteToOut(" " + conversation.systemPrompt + "\n") + h.writeToOut(" " + conversation.systemPrompt + "\n") } else { - h.cfg.WriteToOut(" \n") + h.writeToOut(" \n") } } func (h *runCommandHandler) handleResetPrompt(conversation Conversation) { conversation.Reset() - h.cfg.WriteToOut("Reset chat history\n") + h.writeToOut("Reset chat history\n") } func (h *runCommandHandler) handleSetPrompt(prompt string, mp ModelParameters) { @@ -473,34 +473,34 @@ func (h *runCommandHandler) handleSetPrompt(prompt string, mp ModelParameters) { err := mp.SetParameterByName(name, value) if err != nil { - h.cfg.WriteToOut(err.Error() + "\n") + h.writeToOut(err.Error() + "\n") return } - h.cfg.WriteToOut("Set " + name + " to " + value + "\n") + h.writeToOut("Set " + name + " to " + value + "\n") } else { - h.cfg.WriteToOut("Invalid /set syntax. Usage: /set \n") + h.writeToOut("Invalid /set syntax. Usage: /set \n") } } func (h *runCommandHandler) handleSystemPrompt(prompt string, conversation Conversation) Conversation { conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") - h.cfg.WriteToOut("Updated system prompt\n") + h.writeToOut("Updated system prompt\n") return conversation } func (h *runCommandHandler) handleHelpPrompt() { - h.cfg.WriteToOut("Commands:\n") - h.cfg.WriteToOut(" /bye, /exit, /quit - Exit the chat\n") - h.cfg.WriteToOut(" /parameters - Show current model parameters\n") - h.cfg.WriteToOut(" /reset, /clear - Reset chat context\n") - h.cfg.WriteToOut(" /set - Set a model parameter\n") - h.cfg.WriteToOut(" /system-prompt - Set the system prompt\n") - h.cfg.WriteToOut(" /help - Show this help message\n") + h.writeToOut("Commands:\n") + h.writeToOut(" /bye, /exit, /quit - Exit the chat\n") + h.writeToOut(" /parameters - Show current model parameters\n") + h.writeToOut(" /reset, /clear - Reset chat context\n") + h.writeToOut(" /set - Set a model parameter\n") + h.writeToOut(" /system-prompt - Set the system prompt\n") + h.writeToOut(" /help - Show this help message\n") } func (h *runCommandHandler) handleUnrecognizedPrompt(prompt string) { - h.cfg.WriteToOut("Unknown command '" + prompt + "'. See /help for supported commands.\n") + h.writeToOut("Unknown command '" + prompt + "'. See /help for supported commands.\n") } func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice, messageBuilder strings.Builder) error { @@ -512,14 +512,14 @@ func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice if err != nil { return err } - h.cfg.WriteToOut(*content) + h.writeToOut(*content) } else if choice.Message != nil && choice.Message.Content != nil { content := choice.Message.Content _, err := messageBuilder.WriteString(*content) if err != nil { return err } - h.cfg.WriteToOut(*content) + h.writeToOut(*content) } // Introduce a small delay in between response tokens to better simulate a conversation @@ -529,3 +529,7 @@ func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice return nil } + +func (h *runCommandHandler) writeToOut(message string) { + h.cfg.WriteToOut(message) +} From 0c00848761d7e943220362def2f89fc213698c6b Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 14:17:05 -0500 Subject: [PATCH 25/28] Include client prop on runCommandHandler Less reaching-through at call sites. --- cmd/run/run.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/cmd/run/run.go b/cmd/run/run.go index 120441c0..7b542358 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -365,17 +365,18 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { } type runCommandHandler struct { - ctx context.Context - cfg *command.Config - args []string + ctx context.Context + cfg *command.Config + client azuremodels.Client + args []string } func newRunCommandHandler(cmd *cobra.Command, cfg *command.Config, args []string) *runCommandHandler { - return &runCommandHandler{ctx: cmd.Context(), cfg: cfg, args: args} + return &runCommandHandler{ctx: cmd.Context(), cfg: cfg, client: cfg.Client, args: args} } func (h *runCommandHandler) loadModels() ([]*azuremodels.ModelSummary, error) { - models, err := h.cfg.Client.ListModels(h.ctx) + models, err := h.client.ListModels(h.ctx) if err != nil { return nil, err } @@ -438,7 +439,7 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st } func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions) (sse.Reader[azuremodels.ChatCompletion], error) { - resp, err := h.cfg.Client.GetChatCompletionStream(h.ctx, req) + resp, err := h.client.GetChatCompletionStream(h.ctx, req) if err != nil { return nil, err } From 3ede78e4ef3aba305517260ddb1a7e221bbe18a0 Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 15:30:49 -0500 Subject: [PATCH 26/28] Do we need private module proxy stuff? Related: https://github.com/github/gh-models/pull/15#discussion_r1797364202 --- .github/workflows/lint.yml | 5 +---- .github/workflows/test.yml | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 1659a21c..eed82d4d 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -22,7 +22,7 @@ jobs: fail-fast: false runs-on: ubuntu-latest env: - GOPROXY: https://goproxy.githubapp.com/mod,https://proxy.golang.org/,direct + GOPROXY: https://proxy.golang.org/,direct GOPRIVATE: "" GONOPROXY: "" GONOSUMDB: github.com/github/* @@ -32,9 +32,6 @@ jobs: go-version: ${{ vars.GOVERSION }} check-latest: true - uses: actions/checkout@v4 - - name: Configure Go private module access - run: | - echo "machine goproxy.githubapp.com login nobody password ${{ secrets.GOPROXY_TOKEN }}" >> $HOME/.netrc - name: Lint # This also does checkout, setup-go, and proxy setup. uses: github/go-linter@v1.2.1 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 49bbaed8..3fe22e3a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,15 +20,12 @@ jobs: build: runs-on: ubuntu-latest env: - GOPROXY: https://goproxy.githubapp.com/mod,https://proxy.golang.org/,direct + GOPROXY: https://proxy.golang.org/,direct GOPRIVATE: "" GONOPROXY: "" GONOSUMDB: github.com/github/* steps: - uses: actions/checkout@v4 - - name: Configure Go private module access - run: | - echo "machine goproxy.githubapp.com login nobody password ${{ secrets.GOPROXY_TOKEN }}" >> $HOME/.netrc - uses: actions/setup-go@v5 with: go-version: ${{ vars.GOVERSION }} From 1b48b0f993a0f0ced56dec1b7cff1b3a47b7711b Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 15:50:34 -0500 Subject: [PATCH 27/28] Don't panic when trying to run commands when not auth'd e.g., ./gh-models run gpt-4o-mini "hello" No GitHub token found. Please run 'gh auth login' to authenticate. panic: runtime error: invalid memory address or nil pointer dereference [signal SIGSEGV: segmentation violation code=0x1 addr=0x0 pc=0xba48c6a] goroutine 1 [running]: github.com/github/gh-models/internal/azuremodels.(*AzureClient).ListModels(0x0, {0xd098300, 0xd5ca4a0}) gh-models/internal/azuremodels/azure_client.go:194 +0x20a github.com/github/gh-models/cmd/run.(*runCommandHandler).loadModels(0x4?) gh-models/cmd/run/run.go:382 +0x2a github.com/github/gh-models/cmd/run.NewRunCommand.func1(0xc00021a308, {0xc0002d82e0, 0x2, 0x2}) gh-models/cmd/run/run.go:203 +0x139 github.com/spf13/cobra.(*Command).execute(0xc00021a308, {0xc0002d82a0, 0x2, 0x2}) go/pkg/mod/github.com/spf13/cobra@v1.8.1/command.go:985 +0xaca github.com/spf13/cobra.(*Command).ExecuteC(0xc0001fd808) go/pkg/mod/github.com/spf13/cobra@v1.8.1/command.go:1117 +0x3ff github.com/spf13/cobra.(*Command).ExecuteContextC(0xb6c5305?, {0xd098300?, 0xd5ca4a0?}) go/pkg/mod/github.com/spf13/cobra@v1.8.1/command.go:1050 +0x47 main.mainRun() gh-models/main.go:29 +0x26 main.main() gh-models/main.go:19 +0x13 --- cmd/root.go | 8 +++-- .../azuremodels/unauthenticated_client.go | 30 +++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 internal/azuremodels/unauthenticated_client.go diff --git a/cmd/root.go b/cmd/root.go index b25a47a5..ec225174 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -25,12 +25,16 @@ func NewRootCommand() *cobra.Command { terminal := term.FromEnv() out := terminal.Out() token, _ := auth.TokenForHost("github.com") + + var client azuremodels.Client + if token == "" { util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") - return nil + client = azuremodels.NewUnauthenticatedClient() + } else { + client = azuremodels.NewAzureClient(token) } - client := azuremodels.NewAzureClient(token) cfg := command.NewConfigWithTerminal(terminal, client) cmd.AddCommand(list.NewListCommand(cfg)) diff --git a/internal/azuremodels/unauthenticated_client.go b/internal/azuremodels/unauthenticated_client.go new file mode 100644 index 00000000..2f35aa89 --- /dev/null +++ b/internal/azuremodels/unauthenticated_client.go @@ -0,0 +1,30 @@ +package azuremodels + +import ( + "context" + "errors" +) + +// UnauthenticatedClient is for use by anonymous viewers to talk to the models API. +type UnauthenticatedClient struct { +} + +// NewUnauthenticatedClient contructs a new models API client for an anonymous viewer. +func NewUnauthenticatedClient() *UnauthenticatedClient { + return &UnauthenticatedClient{} +} + +// GetChatCompletionStream returns an error because this functionality requires authentication. +func (c *UnauthenticatedClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) { + return nil, errors.New("not authenticated") +} + +// GetModelDetails returns an error because this functionality requires authentication. +func (c *UnauthenticatedClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { + return nil, errors.New("not authenticated") +} + +// ListModels returns an error because this functionality requires authentication. +func (c *UnauthenticatedClient) ListModels(ctx context.Context) ([]*ModelSummary, error) { + return nil, errors.New("not authenticated") +} From 09cd4aeb406aa57325bec99df8f4e56debc9b1cd Mon Sep 17 00:00:00 2001 From: Sarah Vessels Date: Fri, 11 Oct 2024 15:51:50 -0500 Subject: [PATCH 28/28] List param names https: //github.com/github/gh-models/pull/15#discussion_r1797368843 Co-Authored-By: Christopher Schleiden --- internal/azuremodels/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/azuremodels/client.go b/internal/azuremodels/client.go index 4e1b83f9..9681decd 100644 --- a/internal/azuremodels/client.go +++ b/internal/azuremodels/client.go @@ -7,7 +7,7 @@ type Client interface { // GetChatCompletionStream returns a stream of chat completions using the given options. GetChatCompletionStream(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) // GetModelDetails returns the details of the specified model in a particular registry. - GetModelDetails(context.Context, string, string, string) (*ModelDetails, error) + GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) // ListModels returns a list of available models. ListModels(context.Context) ([]*ModelSummary, error) }