diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index cc5ebdce..eed82d4d 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -2,19 +2,27 @@ name: "go-linter" on: pull_request: + types: [opened, synchronize, reopened] merge_group: workflow_dispatch: + push: + branches: + - 'main' permissions: contents: read +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: lint: strategy: fail-fast: false - runs-on: ubuntu-latest-xl + runs-on: ubuntu-latest env: - GOPROXY: https://goproxy.githubapp.com/mod,https://proxy.golang.org/,direct + GOPROXY: https://proxy.golang.org/,direct GOPRIVATE: "" GONOPROXY: "" GONOSUMDB: github.com/github/* @@ -24,9 +32,6 @@ jobs: go-version: ${{ vars.GOVERSION }} check-latest: true - uses: actions/checkout@v4 - - name: Configure Go private module access - run: | - echo "machine goproxy.githubapp.com login nobody password ${{ secrets.GOPROXY_TOKEN }}" >> $HOME/.netrc - name: Lint # This also does checkout, setup-go, and proxy setup. uses: github/go-linter@v1.2.1 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..3fe22e3a --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,46 @@ +name: "Build and test" + +on: + pull_request: + types: [opened, synchronize, reopened] + workflow_dispatch: + merge_group: + push: + branches: + - 'main' + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + build: + runs-on: ubuntu-latest + env: + GOPROXY: https://proxy.golang.org/,direct + GOPRIVATE: "" + GONOPROXY: "" + GONOSUMDB: github.com/github/* + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: ${{ vars.GOVERSION }} + check-latest: true + - name: Verify go.sum is up to date + run: | + go mod tidy + git diff --exit-code go.sum + if [ $? -ne 0 ]; then + echo "Error: go.sum has changed, please run `go mod tidy` and commit the result" + exit 1 + fi + + - name: Build program + run: go build -v ./... + + - name: Run tests + run: go test -race -cover ./... diff --git a/cmd/list/list.go b/cmd/list/list.go index 6ab3088a..75437939 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -4,12 +4,10 @@ package list import ( "fmt" - "github.com/cli/go-gh/v2/pkg/auth" "github.com/cli/go-gh/v2/pkg/tableprinter" - "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/ux" - "github.com/github/gh-models/pkg/util" + "github.com/github/gh-models/pkg/command" "github.com/mgutz/ansi" "github.com/spf13/cobra" ) @@ -19,24 +17,14 @@ var ( ) // NewListCommand returns a new command to list available GitHub models. -func NewListCommand() *cobra.Command { +func NewListCommand(cfg *command.Config) *cobra.Command { cmd := &cobra.Command{ Use: "list", Short: "List available models", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - terminal := term.FromEnv() - out := terminal.Out() - - token, _ := auth.TokenForHost("github.com") - if token == "" { - util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") - return nil - } - - client := azuremodels.NewClient(token) ctx := cmd.Context() - + client := cfg.Client models, err := client.ListModels(ctx) if err != nil { return err @@ -47,16 +35,13 @@ func NewListCommand() *cobra.Command { models = filterToChatModels(models) ux.SortModels(models) - isTTY := terminal.IsTerminalOutput() - - if isTTY { - util.WriteToOut(out, "\n") - util.WriteToOut(out, fmt.Sprintf("Showing %d available chat models\n", len(models))) - util.WriteToOut(out, "\n") + if cfg.IsTerminalOutput { + cfg.WriteToOut("\n") + cfg.WriteToOut(fmt.Sprintf("Showing %d available chat models\n", len(models))) + cfg.WriteToOut("\n") } - width, _, _ := terminal.Size() - printer := tableprinter.New(out, isTTY, width) + printer := cfg.NewTablePrinter() printer.AddHeader([]string{"DISPLAY NAME", "MODEL NAME"}, tableprinter.WithColor(lightGrayUnderline)) printer.EndRow() diff --git a/cmd/list/list_test.go b/cmd/list/list_test.go new file mode 100644 index 00000000..60ceabed --- /dev/null +++ b/cmd/list/list_test.go @@ -0,0 +1,46 @@ +package list + +import ( + "bytes" + "context" + "testing" + + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" + "github.com/stretchr/testify/require" +) + +func TestList(t *testing.T) { + t.Run("NewListCommand happy path", func(t *testing.T) { + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + ID: "test-id-1", + Name: "test-model-1", + FriendlyName: "Test Model 1", + Task: "chat-completion", + Publisher: "OpenAI", + Summary: "This is a test model", + Version: "1.0", + RegistryName: "azure-openai", + } + listModelsCallCount := 0 + client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) { + listModelsCallCount++ + return []*azuremodels.ModelSummary{modelSummary}, nil + } + buf := new(bytes.Buffer) + cfg := command.NewConfig(buf, buf, client, true, 80) + listCmd := NewListCommand(cfg) + + _, err := listCmd.ExecuteC() + + require.NoError(t, err) + require.Equal(t, 1, listModelsCallCount) + output := buf.String() + require.Contains(t, output, "Showing 1 available chat models") + require.Contains(t, output, "DISPLAY NAME") + require.Contains(t, output, "MODEL NAME") + require.Contains(t, output, modelSummary.FriendlyName) + require.Contains(t, output, modelSummary.Name) + }) +} diff --git a/cmd/root.go b/cmd/root.go index 3e8caf52..ec225174 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -4,9 +4,14 @@ package cmd import ( "strings" + "github.com/cli/go-gh/v2/pkg/auth" + "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/cmd/list" "github.com/github/gh-models/cmd/run" "github.com/github/gh-models/cmd/view" + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" + "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" ) @@ -17,9 +22,24 @@ func NewRootCommand() *cobra.Command { Short: "GitHub Models extension", } - cmd.AddCommand(list.NewListCommand()) - cmd.AddCommand(run.NewRunCommand()) - cmd.AddCommand(view.NewViewCommand()) + terminal := term.FromEnv() + out := terminal.Out() + token, _ := auth.TokenForHost("github.com") + + var client azuremodels.Client + + if token == "" { + util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + client = azuremodels.NewUnauthenticatedClient() + } else { + client = azuremodels.NewAzureClient(token) + } + + cfg := command.NewConfigWithTerminal(terminal, client) + + cmd.AddCommand(list.NewListCommand(cfg)) + cmd.AddCommand(run.NewRunCommand(cfg)) + cmd.AddCommand(view.NewViewCommand(cfg)) // Cobra doesn't have a way to specify a two word command (ie. "gh models"), so set a custom usage template // with `gh`` in it. Cobra will use this template for the root and all child commands. diff --git a/cmd/run/run.go b/cmd/run/run.go index bc12d038..7b542358 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -14,11 +14,10 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/briandowns/spinner" - "github.com/cli/go-gh/v2/pkg/auth" - "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/sse" "github.com/github/gh-models/internal/ux" + "github.com/github/gh-models/pkg/command" "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -190,13 +189,13 @@ func isPipe(r io.Reader) bool { } // NewRunCommand returns a new gh command for running a model. -func NewRunCommand() *cobra.Command { +func NewRunCommand(cfg *command.Config) *cobra.Command { cmd := &cobra.Command{ Use: "run [model] [prompt]", Short: "Run inference with the specified model", Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { - cmdHandler := newRunCommandHandler(cmd, args) + cmdHandler := newRunCommandHandler(cmd, cfg, args) if cmdHandler == nil { return nil } @@ -307,7 +306,7 @@ func NewRunCommand() *cobra.Command { mp.UpdateRequest(&req) - sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(cmdHandler.errOut)) + sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(cmdHandler.cfg.ErrOut)) sp.Start() //nolint:gocritic,revive // TODO defer sp.Stop() @@ -340,7 +339,7 @@ func NewRunCommand() *cobra.Command { } } - util.WriteToOut(cmdHandler.out, "\n") + cmdHandler.writeToOut("\n") _, err = messageBuilder.WriteString("\n") if err != nil { return err @@ -366,30 +365,14 @@ func NewRunCommand() *cobra.Command { } type runCommandHandler struct { - ctx context.Context - terminal term.Term - out io.Writer - errOut io.Writer - client *azuremodels.Client - args []string + ctx context.Context + cfg *command.Config + client azuremodels.Client + args []string } -func newRunCommandHandler(cmd *cobra.Command, args []string) *runCommandHandler { - terminal := term.FromEnv() - out := terminal.Out() - token, _ := auth.TokenForHost("github.com") - if token == "" { - util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") - return nil - } - return &runCommandHandler{ - ctx: cmd.Context(), - terminal: terminal, - out: out, - args: args, - errOut: terminal.ErrOut(), - client: azuremodels.NewClient(token), - } +func newRunCommandHandler(cmd *cobra.Command, cfg *command.Config, args []string) *runCommandHandler { + return &runCommandHandler{ctx: cmd.Context(), cfg: cfg, client: cfg.Client, args: args} } func (h *runCommandHandler) loadModels() ([]*azuremodels.ModelSummary, error) { @@ -464,23 +447,23 @@ func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCo } func (h *runCommandHandler) handleParametersPrompt(conversation Conversation, mp ModelParameters) { - util.WriteToOut(h.out, "Current parameters:\n") + h.writeToOut("Current parameters:\n") names := []string{"max-tokens", "temperature", "top-p"} for _, name := range names { - util.WriteToOut(h.out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) + h.writeToOut(fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) } - util.WriteToOut(h.out, "\n") - util.WriteToOut(h.out, "System Prompt:\n") + h.writeToOut("\n") + h.writeToOut("System Prompt:\n") if conversation.systemPrompt != "" { - util.WriteToOut(h.out, " "+conversation.systemPrompt+"\n") + h.writeToOut(" " + conversation.systemPrompt + "\n") } else { - util.WriteToOut(h.out, " \n") + h.writeToOut(" \n") } } func (h *runCommandHandler) handleResetPrompt(conversation Conversation) { conversation.Reset() - util.WriteToOut(h.out, "Reset chat history\n") + h.writeToOut("Reset chat history\n") } func (h *runCommandHandler) handleSetPrompt(prompt string, mp ModelParameters) { @@ -491,34 +474,34 @@ func (h *runCommandHandler) handleSetPrompt(prompt string, mp ModelParameters) { err := mp.SetParameterByName(name, value) if err != nil { - util.WriteToOut(h.out, err.Error()+"\n") + h.writeToOut(err.Error() + "\n") return } - util.WriteToOut(h.out, "Set "+name+" to "+value+"\n") + h.writeToOut("Set " + name + " to " + value + "\n") } else { - util.WriteToOut(h.out, "Invalid /set syntax. Usage: /set \n") + h.writeToOut("Invalid /set syntax. Usage: /set \n") } } func (h *runCommandHandler) handleSystemPrompt(prompt string, conversation Conversation) Conversation { conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") - util.WriteToOut(h.out, "Updated system prompt\n") + h.writeToOut("Updated system prompt\n") return conversation } func (h *runCommandHandler) handleHelpPrompt() { - util.WriteToOut(h.out, "Commands:\n") - util.WriteToOut(h.out, " /bye, /exit, /quit - Exit the chat\n") - util.WriteToOut(h.out, " /parameters - Show current model parameters\n") - util.WriteToOut(h.out, " /reset, /clear - Reset chat context\n") - util.WriteToOut(h.out, " /set - Set a model parameter\n") - util.WriteToOut(h.out, " /system-prompt - Set the system prompt\n") - util.WriteToOut(h.out, " /help - Show this help message\n") + h.writeToOut("Commands:\n") + h.writeToOut(" /bye, /exit, /quit - Exit the chat\n") + h.writeToOut(" /parameters - Show current model parameters\n") + h.writeToOut(" /reset, /clear - Reset chat context\n") + h.writeToOut(" /set - Set a model parameter\n") + h.writeToOut(" /system-prompt - Set the system prompt\n") + h.writeToOut(" /help - Show this help message\n") } func (h *runCommandHandler) handleUnrecognizedPrompt(prompt string) { - util.WriteToOut(h.out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") + h.writeToOut("Unknown command '" + prompt + "'. See /help for supported commands.\n") } func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice, messageBuilder strings.Builder) error { @@ -530,20 +513,24 @@ func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice if err != nil { return err } - util.WriteToOut(h.out, *content) + h.writeToOut(*content) } else if choice.Message != nil && choice.Message.Content != nil { content := choice.Message.Content _, err := messageBuilder.WriteString(*content) if err != nil { return err } - util.WriteToOut(h.out, *content) + h.writeToOut(*content) } // Introduce a small delay in between response tokens to better simulate a conversation - if h.terminal.IsTerminalOutput() { + if h.cfg.IsTerminalOutput { time.Sleep(10 * time.Millisecond) } return nil } + +func (h *runCommandHandler) writeToOut(message string) { + h.cfg.WriteToOut(message) +} diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go new file mode 100644 index 00000000..92755745 --- /dev/null +++ b/cmd/run/run_test.go @@ -0,0 +1,62 @@ +package run + +import ( + "bytes" + "context" + "testing" + + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/internal/sse" + "github.com/github/gh-models/pkg/command" + "github.com/github/gh-models/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestRun(t *testing.T) { + t.Run("NewRunCommand happy path", func(t *testing.T) { + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + ID: "test-id-1", + Name: "test-model-1", + FriendlyName: "Test Model 1", + Task: "chat-completion", + Publisher: "OpenAI", + Summary: "This is a test model", + Version: "1.0", + RegistryName: "azure-openai", + } + listModelsCallCount := 0 + client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) { + listModelsCallCount++ + return []*azuremodels.ModelSummary{modelSummary}, nil + } + fakeMessageFromModel := "yes hello this is dog" + chatChoice := azuremodels.ChatChoice{ + Message: &azuremodels.ChatChoiceMessage{ + Content: util.Ptr(fakeMessageFromModel), + Role: util.Ptr(string(azuremodels.ChatMessageRoleAssistant)), + }, + } + chatCompletion := azuremodels.ChatCompletion{Choices: []azuremodels.ChatChoice{chatChoice}} + chatResp := &azuremodels.ChatCompletionResponse{ + Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}), + } + getChatCompletionCallCount := 0 + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + getChatCompletionCallCount++ + return chatResp, nil + } + buf := new(bytes.Buffer) + cfg := command.NewConfig(buf, buf, client, true, 80) + runCmd := NewRunCommand(cfg) + runCmd.SetArgs([]string{modelSummary.Name, "this is my prompt"}) + + _, err := runCmd.ExecuteC() + + require.NoError(t, err) + require.Equal(t, 1, listModelsCallCount) + require.Equal(t, 1, getChatCompletionCallCount) + output := buf.String() + require.Contains(t, output, fakeMessageFromModel) + }) +} diff --git a/cmd/view/model_printer.go b/cmd/view/model_printer.go index 63790f3c..6776c571 100644 --- a/cmd/view/model_printer.go +++ b/cmd/view/model_printer.go @@ -5,8 +5,8 @@ import ( "github.com/cli/cli/v2/pkg/markdown" "github.com/cli/go-gh/v2/pkg/tableprinter" - "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" "github.com/mgutz/ansi" ) @@ -21,17 +21,20 @@ type modelPrinter struct { terminalWidth int } -func newModelPrinter(summary *azuremodels.ModelSummary, details *azuremodels.ModelDetails, terminal term.Term) modelPrinter { - width, _, _ := terminal.Size() - printer := tableprinter.New(terminal.Out(), terminal.IsTerminalOutput(), width) - return modelPrinter{modelSummary: summary, modelDetails: details, printer: printer, terminalWidth: width} +func newModelPrinter(summary *azuremodels.ModelSummary, details *azuremodels.ModelDetails, cfg *command.Config) modelPrinter { + return modelPrinter{ + modelSummary: summary, + modelDetails: details, + printer: cfg.NewTablePrinter(), + terminalWidth: cfg.TerminalWidth, + } } func (p *modelPrinter) render() error { modelSummary := p.modelSummary if modelSummary != nil { p.printLabelledLine("Display name:", modelSummary.FriendlyName) - p.printLabelledLine("Summary name:", modelSummary.Name) + p.printLabelledLine("Model name:", modelSummary.Name) p.printLabelledLine("Publisher:", modelSummary.Publisher) p.printLabelledLine("Summary:", modelSummary.Summary) } diff --git a/cmd/view/view.go b/cmd/view/view.go index 777281d7..37e34e03 100644 --- a/cmd/view/view.go +++ b/cmd/view/view.go @@ -5,32 +5,21 @@ import ( "fmt" "github.com/AlecAivazis/survey/v2" - "github.com/cli/go-gh/v2/pkg/auth" - "github.com/cli/go-gh/v2/pkg/term" "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/ux" - "github.com/github/gh-models/pkg/util" + "github.com/github/gh-models/pkg/command" "github.com/spf13/cobra" ) // NewViewCommand returns a new command to view details about a model. -func NewViewCommand() *cobra.Command { +func NewViewCommand(cfg *command.Config) *cobra.Command { cmd := &cobra.Command{ Use: "view [model]", Short: "View details about a model", Args: cobra.ArbitraryArgs, RunE: func(cmd *cobra.Command, args []string) error { - terminal := term.FromEnv() - - token, _ := auth.TokenForHost("github.com") - if token == "" { - util.WriteToOut(terminal.Out(), "No GitHub token found. Please run 'gh auth login' to authenticate.\n") - return nil - } - - client := azuremodels.NewClient(token) ctx := cmd.Context() - + client := cfg.Client models, err := client.ListModels(ctx) if err != nil { return err @@ -73,7 +62,7 @@ func NewViewCommand() *cobra.Command { return err } - modelPrinter := newModelPrinter(modelSummary, modelDetails, terminal) + modelPrinter := newModelPrinter(modelSummary, modelDetails, cfg) err = modelPrinter.render() if err != nil { diff --git a/cmd/view/view_test.go b/cmd/view/view_test.go new file mode 100644 index 00000000..8348b31e --- /dev/null +++ b/cmd/view/view_test.go @@ -0,0 +1,89 @@ +package view + +import ( + "bytes" + "context" + "testing" + + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" + "github.com/stretchr/testify/require" +) + +func TestView(t *testing.T) { + t.Run("NewViewCommand happy path", func(t *testing.T) { + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + ID: "test-id-1", + Name: "test-model-1", + FriendlyName: "Test Model 1", + Task: "chat-completion", + Publisher: "OpenAI", + Summary: "This is a test model", + Version: "1.0", + RegistryName: "azure-openai", + } + listModelsCallCount := 0 + client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) { + listModelsCallCount++ + return []*azuremodels.ModelSummary{modelSummary}, nil + } + getModelDetailsCallCount := 0 + modelDetails := &azuremodels.ModelDetails{ + Description: "Fake description", + Evaluation: "Fake evaluation", + License: "MIT", + LicenseDescription: "This is a test license", + Tags: []string{"tag1", "tag2"}, + SupportedInputModalities: []string{"text", "carrier-pigeon"}, + SupportedOutputModalities: []string{"underwater-signals"}, + SupportedLanguages: []string{"English", "Spanish"}, + MaxOutputTokens: 123, + MaxInputTokens: 456, + RateLimitTier: "mediumish", + } + client.MockGetModelDetails = func(ctx context.Context, registryName, modelName, version string) (*azuremodels.ModelDetails, error) { + getModelDetailsCallCount++ + return modelDetails, nil + } + buf := new(bytes.Buffer) + cfg := command.NewConfig(buf, buf, client, true, 80) + viewCmd := NewViewCommand(cfg) + viewCmd.SetArgs([]string{modelSummary.Name}) + + _, err := viewCmd.ExecuteC() + + require.NoError(t, err) + require.Equal(t, 1, listModelsCallCount) + require.Equal(t, 1, getModelDetailsCallCount) + output := buf.String() + require.Contains(t, output, "Display name:") + require.Contains(t, output, modelSummary.FriendlyName) + require.Contains(t, output, "Model name:") + require.Contains(t, output, modelSummary.Name) + require.Contains(t, output, "Publisher:") + require.Contains(t, output, modelSummary.Publisher) + require.Contains(t, output, "Summary:") + require.Contains(t, output, modelSummary.Summary) + require.Contains(t, output, "Context:") + require.Contains(t, output, "up to 456 input tokens and 123 output tokens") + require.Contains(t, output, "Rate limit tier:") + require.Contains(t, output, "mediumish") + require.Contains(t, output, "Tags:") + require.Contains(t, output, "tag1, tag2") + require.Contains(t, output, "Supported input types:") + require.Contains(t, output, "text, carrier-pigeon") + require.Contains(t, output, "Supported output types:") + require.Contains(t, output, "underwater-signals") + require.Contains(t, output, "Supported languages:") + require.Contains(t, output, "English, Spanish") + require.Contains(t, output, "License:") + require.Contains(t, output, modelDetails.License) + require.Contains(t, output, "License description:") + require.Contains(t, output, modelDetails.LicenseDescription) + require.Contains(t, output, "Description:") + require.Contains(t, output, modelDetails.Description) + require.Contains(t, output, "Evaluation:") + require.Contains(t, output, modelDetails.Evaluation) + }) +} diff --git a/go.mod b/go.mod index 27a21d5a..9b73ace2 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.9.0 golang.org/x/text v0.18.0 ) @@ -24,6 +25,7 @@ require ( github.com/charmbracelet/x/exp/term v0.0.0-20240425164147-ba2a9512b05f // indirect github.com/cli/safeexec v1.0.1 // indirect github.com/cli/shurcooL-graphql v0.0.4 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dlclark/regexp2 v1.4.0 // indirect github.com/fatih/color v1.16.0 // indirect github.com/gorilla/css v1.0.0 // indirect @@ -39,6 +41,7 @@ require ( github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.15.2 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e // indirect github.com/yuin/goldmark v1.5.4 // indirect diff --git a/internal/azuremodels/azure_client.go b/internal/azuremodels/azure_client.go new file mode 100644 index 00000000..7d58da70 --- /dev/null +++ b/internal/azuremodels/azure_client.go @@ -0,0 +1,279 @@ +// Package azuremodels provides a client for interacting with the Azure models API. +package azuremodels + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/cli/go-gh/v2/pkg/api" + "github.com/github/gh-models/internal/sse" + "golang.org/x/text/language" + "golang.org/x/text/language/display" +) + +// AzureClient provides a client for interacting with the Azure models API. +type AzureClient struct { + client *http.Client + token string +} + +const ( + prodInferenceURL = "https://models.inference.ai.azure.com/chat/completions" + azureAiStudioURL = "https://api.catalog.azureml.ms" + prodModelsURL = azureAiStudioURL + "/asset-gallery/v1.0/models" +) + +// NewAzureClient returns a new Azure client using the given auth token. +func NewAzureClient(authToken string) *AzureClient { + httpClient, _ := api.DefaultHTTPClient() + return &AzureClient{ + client: httpClient, + token: authToken, + } +} + +// GetChatCompletionStream returns a stream of chat completions using the given options. +func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions) (*ChatCompletionResponse, error) { + // Check if the model name is `o1-mini` or `o1-preview` + if req.Model == "o1-mini" || req.Model == "o1-preview" { + req.Stream = false + } else { + req.Stream = true + } + + bodyBytes, err := json.Marshal(req) + if err != nil { + return nil, err + } + + body := bytes.NewReader(bodyBytes) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodInferenceURL, body) + if err != nil { + return nil, err + } + + httpReq.Header.Set("Authorization", "Bearer "+c.token) + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.client.Do(httpReq) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + // If we aren't going to return an SSE stream, then ensure the response body is closed. + defer resp.Body.Close() + return nil, c.handleHTTPError(resp) + } + + var chatCompletionResponse ChatCompletionResponse + + if req.Stream { + // Handle streamed response + chatCompletionResponse.Reader = sse.NewEventReader[ChatCompletion](resp.Body) + } else { + var completion ChatCompletion + if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil { + return nil, err + } + + // Create a mock reader that returns the decoded completion + mockReader := sse.NewMockEventReader([]ChatCompletion{completion}) + chatCompletionResponse.Reader = mockReader + } + + return &chatCompletionResponse, nil +} + +// GetModelDetails returns the details of the specified model in a particular registry. +func (c *AzureClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { + url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return nil, err + } + + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.client.Do(httpReq) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, c.handleHTTPError(resp) + } + + decoder := json.NewDecoder(resp.Body) + decoder.UseNumber() + + var detailsResponse modelCatalogDetailsResponse + err = decoder.Decode(&detailsResponse) + if err != nil { + return nil, err + } + + modelDetails := &ModelDetails{ + Description: detailsResponse.Description, + License: detailsResponse.License, + LicenseDescription: detailsResponse.LicenseDescription, + Notes: detailsResponse.Notes, + Tags: lowercaseStrings(detailsResponse.Keywords), + Evaluation: detailsResponse.Evaluation, + } + + modelLimits := detailsResponse.ModelLimits + if modelLimits != nil { + modelDetails.SupportedInputModalities = modelLimits.SupportedInputModalities + modelDetails.SupportedOutputModalities = modelLimits.SupportedOutputModalities + modelDetails.SupportedLanguages = convertLanguageCodesToNames(modelLimits.SupportedLanguages) + + textLimits := modelLimits.TextLimits + if textLimits != nil { + modelDetails.MaxOutputTokens = textLimits.MaxOutputTokens + modelDetails.MaxInputTokens = textLimits.InputContextWindow + } + } + + playgroundLimits := detailsResponse.PlaygroundLimits + if playgroundLimits != nil { + modelDetails.RateLimitTier = playgroundLimits.RateLimitTier + } + + return modelDetails, nil +} + +func convertLanguageCodesToNames(input []string) []string { + output := make([]string, len(input)) + english := display.English.Languages() + for i, code := range input { + tag := language.MustParse(code) + output[i] = english.Name(tag) + } + return output +} + +func lowercaseStrings(input []string) []string { + output := make([]string, len(input)) + for i, s := range input { + output[i] = strings.ToLower(s) + } + return output +} + +// ListModels returns a list of available models. +func (c *AzureClient) ListModels(ctx context.Context) ([]*ModelSummary, error) { + body := bytes.NewReader([]byte(` + { + "filters": [ + { "field": "freePlayground", "values": ["true"], "operator": "eq"}, + { "field": "labels", "values": ["latest"], "operator": "eq"} + ], + "order": [ + { "field": "displayName", "direction": "asc" } + ] + } + `)) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodModelsURL, body) + if err != nil { + return nil, err + } + + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.client.Do(httpReq) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, c.handleHTTPError(resp) + } + + decoder := json.NewDecoder(resp.Body) + decoder.UseNumber() + + var searchResponse modelCatalogSearchResponse + err = decoder.Decode(&searchResponse) + if err != nil { + return nil, err + } + + models := make([]*ModelSummary, 0, len(searchResponse.Summaries)) + for _, summary := range searchResponse.Summaries { + inferenceTask := "" + if len(summary.InferenceTasks) > 0 { + inferenceTask = summary.InferenceTasks[0] + } + + models = append(models, &ModelSummary{ + ID: summary.AssetID, + Name: summary.Name, + FriendlyName: summary.DisplayName, + Task: inferenceTask, + Publisher: summary.Publisher, + Summary: summary.Summary, + Version: summary.Version, + RegistryName: summary.RegistryName, + }) + } + + return models, nil +} + +func (c *AzureClient) handleHTTPError(resp *http.Response) error { + sb := strings.Builder{} + var err error + + switch resp.StatusCode { + case http.StatusUnauthorized: + _, err = sb.WriteString("unauthorized") + if err != nil { + return err + } + + case http.StatusBadRequest: + _, err = sb.WriteString("bad request") + if err != nil { + return err + } + + default: + _, err = sb.WriteString("unexpected response from the server: " + resp.Status) + if err != nil { + return err + } + } + + body, _ := io.ReadAll(resp.Body) + if len(body) > 0 { + _, err = sb.WriteString("\n") + if err != nil { + return err + } + + _, err = sb.Write(body) + if err != nil { + return err + } + + _, err = sb.WriteString("\n") + if err != nil { + return err + } + } + + return errors.New(sb.String()) +} diff --git a/internal/azuremodels/client.go b/internal/azuremodels/client.go index a4b60d39..9681decd 100644 --- a/internal/azuremodels/client.go +++ b/internal/azuremodels/client.go @@ -1,279 +1,13 @@ -// Package azuremodels provides a client for interacting with the Azure models API. package azuremodels -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strings" +import "context" - "github.com/cli/go-gh/v2/pkg/api" - "github.com/github/gh-models/internal/sse" - "golang.org/x/text/language" - "golang.org/x/text/language/display" -) - -// Client provides a client for interacting with the Azure models API. -type Client struct { - client *http.Client - token string -} - -const ( - prodInferenceURL = "https://models.inference.ai.azure.com/chat/completions" - azureAiStudioURL = "https://api.catalog.azureml.ms" - prodModelsURL = azureAiStudioURL + "/asset-gallery/v1.0/models" -) - -// NewClient returns a new client using the given auth token. -func NewClient(authToken string) *Client { - httpClient, _ := api.DefaultHTTPClient() - return &Client{ - client: httpClient, - token: authToken, - } -} - -// GetChatCompletionStream returns a stream of chat completions for the given request. -func (c *Client) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions) (*ChatCompletionResponse, error) { - // Check if the model name is `o1-mini` or `o1-preview` - if req.Model == "o1-mini" || req.Model == "o1-preview" { - req.Stream = false - } else { - req.Stream = true - } - - bodyBytes, err := json.Marshal(req) - if err != nil { - return nil, err - } - - body := bytes.NewReader(bodyBytes) - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodInferenceURL, body) - if err != nil { - return nil, err - } - - httpReq.Header.Set("Authorization", "Bearer "+c.token) - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.client.Do(httpReq) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - // If we aren't going to return an SSE stream, then ensure the response body is closed. - defer resp.Body.Close() - return nil, c.handleHTTPError(resp) - } - - var chatCompletionResponse ChatCompletionResponse - - if req.Stream { - // Handle streamed response - chatCompletionResponse.Reader = sse.NewEventReader[ChatCompletion](resp.Body) - } else { - var completion ChatCompletion - if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil { - return nil, err - } - - // Create a mock reader that returns the decoded completion - mockReader := sse.NewMockEventReader([]ChatCompletion{completion}) - chatCompletionResponse.Reader = mockReader - } - - return &chatCompletionResponse, nil -} - -// GetModelDetails returns the details of the specified model in a prticular registry. -func (c *Client) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { - url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version) - httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) - if err != nil { - return nil, err - } - - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.client.Do(httpReq) - if err != nil { - return nil, err - } - - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, c.handleHTTPError(resp) - } - - decoder := json.NewDecoder(resp.Body) - decoder.UseNumber() - - var detailsResponse modelCatalogDetailsResponse - err = decoder.Decode(&detailsResponse) - if err != nil { - return nil, err - } - - modelDetails := &ModelDetails{ - Description: detailsResponse.Description, - License: detailsResponse.License, - LicenseDescription: detailsResponse.LicenseDescription, - Notes: detailsResponse.Notes, - Tags: lowercaseStrings(detailsResponse.Keywords), - Evaluation: detailsResponse.Evaluation, - } - - modelLimits := detailsResponse.ModelLimits - if modelLimits != nil { - modelDetails.SupportedInputModalities = modelLimits.SupportedInputModalities - modelDetails.SupportedOutputModalities = modelLimits.SupportedOutputModalities - modelDetails.SupportedLanguages = convertLanguageCodesToNames(modelLimits.SupportedLanguages) - - textLimits := modelLimits.TextLimits - if textLimits != nil { - modelDetails.MaxOutputTokens = textLimits.MaxOutputTokens - modelDetails.MaxInputTokens = textLimits.InputContextWindow - } - } - - playgroundLimits := detailsResponse.PlaygroundLimits - if playgroundLimits != nil { - modelDetails.RateLimitTier = playgroundLimits.RateLimitTier - } - - return modelDetails, nil -} - -func convertLanguageCodesToNames(input []string) []string { - output := make([]string, len(input)) - english := display.English.Languages() - for i, code := range input { - tag := language.MustParse(code) - output[i] = english.Name(tag) - } - return output -} - -func lowercaseStrings(input []string) []string { - output := make([]string, len(input)) - for i, s := range input { - output[i] = strings.ToLower(s) - } - return output -} - -// ListModels returns a list of available models. -func (c *Client) ListModels(ctx context.Context) ([]*ModelSummary, error) { - body := bytes.NewReader([]byte(` - { - "filters": [ - { "field": "freePlayground", "values": ["true"], "operator": "eq"}, - { "field": "labels", "values": ["latest"], "operator": "eq"} - ], - "order": [ - { "field": "displayName", "direction": "asc" } - ] - } - `)) - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodModelsURL, body) - if err != nil { - return nil, err - } - - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.client.Do(httpReq) - if err != nil { - return nil, err - } - - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, c.handleHTTPError(resp) - } - - decoder := json.NewDecoder(resp.Body) - decoder.UseNumber() - - var searchResponse modelCatalogSearchResponse - err = decoder.Decode(&searchResponse) - if err != nil { - return nil, err - } - - models := make([]*ModelSummary, 0, len(searchResponse.Summaries)) - for _, summary := range searchResponse.Summaries { - inferenceTask := "" - if len(summary.InferenceTasks) > 0 { - inferenceTask = summary.InferenceTasks[0] - } - - models = append(models, &ModelSummary{ - ID: summary.AssetID, - Name: summary.Name, - FriendlyName: summary.DisplayName, - Task: inferenceTask, - Publisher: summary.Publisher, - Summary: summary.Summary, - Version: summary.Version, - RegistryName: summary.RegistryName, - }) - } - - return models, nil -} - -func (c *Client) handleHTTPError(resp *http.Response) error { - sb := strings.Builder{} - var err error - - switch resp.StatusCode { - case http.StatusUnauthorized: - _, err = sb.WriteString("unauthorized") - if err != nil { - return err - } - - case http.StatusBadRequest: - _, err = sb.WriteString("bad request") - if err != nil { - return err - } - - default: - _, err = sb.WriteString("unexpected response from the server: " + resp.Status) - if err != nil { - return err - } - } - - body, _ := io.ReadAll(resp.Body) - if len(body) > 0 { - _, err = sb.WriteString("\n") - if err != nil { - return err - } - - _, err = sb.Write(body) - if err != nil { - return err - } - - _, err = sb.WriteString("\n") - if err != nil { - return err - } - } - - return errors.New(sb.String()) +// Client represents a client for interacting with an API about models. +type Client interface { + // GetChatCompletionStream returns a stream of chat completions using the given options. + GetChatCompletionStream(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) + // GetModelDetails returns the details of the specified model in a particular registry. + GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) + // ListModels returns a list of available models. + ListModels(context.Context) ([]*ModelSummary, error) } diff --git a/internal/azuremodels/mock_client.go b/internal/azuremodels/mock_client.go new file mode 100644 index 00000000..c15cfb6d --- /dev/null +++ b/internal/azuremodels/mock_client.go @@ -0,0 +1,43 @@ +package azuremodels + +import ( + "context" + "errors" +) + +// MockClient provides a client for interacting with the Azure models API in tests. +type MockClient struct { + MockGetChatCompletionStream func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) + MockGetModelDetails func(context.Context, string, string, string) (*ModelDetails, error) + MockListModels func(context.Context) ([]*ModelSummary, error) +} + +// NewMockClient returns a new mock client for stubbing out interactions with the models API. +func NewMockClient() *MockClient { + return &MockClient{ + MockGetChatCompletionStream: func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) { + return nil, errors.New("GetChatCompletionStream not implemented") + }, + MockGetModelDetails: func(context.Context, string, string, string) (*ModelDetails, error) { + return nil, errors.New("GetModelDetails not implemented") + }, + MockListModels: func(context.Context) ([]*ModelSummary, error) { + return nil, errors.New("ListModels not implemented") + }, + } +} + +// GetChatCompletionStream calls the mocked function for getting a stream of chat completions for the given request. +func (c *MockClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) { + return c.MockGetChatCompletionStream(ctx, opt) +} + +// GetModelDetails calls the mocked function for getting the details of the specified model in a particular registry. +func (c *MockClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { + return c.MockGetModelDetails(ctx, registry, modelName, version) +} + +// ListModels calls the mocked function for getting a list of available models. +func (c *MockClient) ListModels(ctx context.Context) ([]*ModelSummary, error) { + return c.MockListModels(ctx) +} diff --git a/internal/azuremodels/types.go b/internal/azuremodels/types.go index 98138faf..d8d5a52d 100644 --- a/internal/azuremodels/types.go +++ b/internal/azuremodels/types.go @@ -36,7 +36,8 @@ type ChatCompletionOptions struct { TopP *float64 `json:"top_p,omitempty"` } -type chatChoiceMessage struct { +// ChatChoiceMessage is a message from a choice in a chat conversation. +type ChatChoiceMessage struct { Content *string `json:"content,omitempty"` Role *string `json:"role,omitempty"` } @@ -51,7 +52,7 @@ type ChatChoice struct { Delta *chatChoiceDelta `json:"delta,omitempty"` FinishReason string `json:"finish_reason"` Index int32 `json:"index"` - Message *chatChoiceMessage `json:"message,omitempty"` + Message *ChatChoiceMessage `json:"message,omitempty"` } // ChatCompletion represents a chat completion. diff --git a/internal/azuremodels/unauthenticated_client.go b/internal/azuremodels/unauthenticated_client.go new file mode 100644 index 00000000..2f35aa89 --- /dev/null +++ b/internal/azuremodels/unauthenticated_client.go @@ -0,0 +1,30 @@ +package azuremodels + +import ( + "context" + "errors" +) + +// UnauthenticatedClient is for use by anonymous viewers to talk to the models API. +type UnauthenticatedClient struct { +} + +// NewUnauthenticatedClient contructs a new models API client for an anonymous viewer. +func NewUnauthenticatedClient() *UnauthenticatedClient { + return &UnauthenticatedClient{} +} + +// GetChatCompletionStream returns an error because this functionality requires authentication. +func (c *UnauthenticatedClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) { + return nil, errors.New("not authenticated") +} + +// GetModelDetails returns an error because this functionality requires authentication. +func (c *UnauthenticatedClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { + return nil, errors.New("not authenticated") +} + +// ListModels returns an error because this functionality requires authentication. +func (c *UnauthenticatedClient) ListModels(ctx context.Context) ([]*ModelSummary, error) { + return nil, errors.New("not authenticated") +} diff --git a/pkg/command/config.go b/pkg/command/config.go new file mode 100644 index 00000000..36296b44 --- /dev/null +++ b/pkg/command/config.go @@ -0,0 +1,52 @@ +// Package command provides shared configuration for sub-commands in the gh-models extension. +package command + +import ( + "io" + + "github.com/cli/go-gh/v2/pkg/tableprinter" + "github.com/cli/go-gh/v2/pkg/term" + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/util" +) + +// Config represents configurable settings for a command. +type Config struct { + // Out is where standard output is written. + Out io.Writer + // ErrOut is where error output is written. + ErrOut io.Writer + // Client is the client for interacting with the models service. + Client azuremodels.Client + // IsTerminalOutput is true if the output should be formatted for a terminal. + IsTerminalOutput bool + // TerminalWidth is the width of the terminal. + TerminalWidth int +} + +// NewConfig returns a new command configuration. +func NewConfig(out, errOut io.Writer, client azuremodels.Client, isTerminalOutput bool, width int) *Config { + return &Config{Out: out, ErrOut: errOut, Client: client, IsTerminalOutput: isTerminalOutput, TerminalWidth: width} +} + +// NewConfigWithTerminal returns a new command configuration using the given terminal. +func NewConfigWithTerminal(terminal term.Term, client azuremodels.Client) *Config { + width, _, _ := terminal.Size() + return &Config{ + Out: terminal.Out(), + ErrOut: terminal.ErrOut(), + Client: client, + IsTerminalOutput: terminal.IsTerminalOutput(), + TerminalWidth: width, + } +} + +// NewTablePrinter initializes a table printer with terminal mode and terminal width. +func (c *Config) NewTablePrinter() tableprinter.TablePrinter { + return tableprinter.New(c.Out, c.IsTerminalOutput, c.TerminalWidth) +} + +// WriteToOut writes a message to the configured stdout writer. +func (c *Config) WriteToOut(message string) { + util.WriteToOut(c.Out, message) +}