diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3fe22e3a..7078b8ec 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,7 +28,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: - go-version: ${{ vars.GOVERSION }} + go-version: ">=1.22" check-latest: true - name: Verify go.sum is up to date run: | @@ -43,4 +43,6 @@ jobs: run: go build -v ./... - name: Run tests - run: go test -race -cover ./... + run: | + go version + go test -race -cover ./... diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4cdbaa43..c8bb608b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,7 +14,7 @@ Please note that this project is released with a [Contributor Code of Conduct](C These are one time installations required to be able to test your changes locally as part of the pull request (PR) submission process. -1. Install Go [through download](https://go.dev/doc/install) | [through Homebrew](https://formulae.brew.sh/formula/go) and ensure it's a least version 1.22 +1. Install Go [through download](https://go.dev/doc/install) | [through Homebrew](https://formulae.brew.sh/formula/go) and ensure it's at least version 1.22 ## Submitting a pull request diff --git a/DEV.md b/DEV.md index 9f9c2ee3..36c44fd1 100644 --- a/DEV.md +++ b/DEV.md @@ -1,6 +1,6 @@ -## Developing +# Developing -### Prerequisites +## Prerequisites The extension requires the [`gh` CLI](https://cli.github.com/) to be installed and added to the `PATH`. Users must also authenticate via `gh auth` before using the extension. @@ -12,12 +12,12 @@ $ go version go version go1.22.x ``` -### Building +## Building To build the project, run `script/build`. After building, you can run the binary locally, for example: `./gh-models list`. -### Testing +## Testing To run lint tests, unit tests, and other Go-related checks before submitting a pull request, use: @@ -34,7 +34,7 @@ make vet # to find suspicious constructs make tidy # to keep dependencies up-to-date ``` -### Releasing +## Releasing When upgrading or installing the extension using `gh extension upgrade github/gh-models` or `gh extension install github/gh-models`, the latest release will be pulled, not the latest commit. Therefore, all diff --git a/cmd/list/list_test.go b/cmd/list/list_test.go index 60ceabed..a0c462ad 100644 --- a/cmd/list/list_test.go +++ b/cmd/list/list_test.go @@ -43,4 +43,19 @@ func TestList(t *testing.T) { require.Contains(t, output, modelSummary.FriendlyName) require.Contains(t, output, modelSummary.Name) }) + + t.Run("--help prints usage info", func(t *testing.T) { + outBuf := new(bytes.Buffer) + errBuf := new(bytes.Buffer) + listCmd := NewListCommand(nil) + listCmd.SetOut(outBuf) + listCmd.SetErr(errBuf) + listCmd.SetArgs([]string{"--help"}) + + err := listCmd.Help() + + require.NoError(t, err) + require.Contains(t, outBuf.String(), "List available models") + require.Empty(t, errBuf.String()) + }) } diff --git a/cmd/root.go b/cmd/root.go index ec225174..f3c2bd58 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -32,7 +32,12 @@ func NewRootCommand() *cobra.Command { util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") client = azuremodels.NewUnauthenticatedClient() } else { - client = azuremodels.NewAzureClient(token) + var err error + client, err = azuremodels.NewDefaultAzureClient(token) + if err != nil { + util.WriteToOut(terminal.ErrOut(), "Error creating Azure client: "+err.Error()) + return nil + } } cfg := command.NewConfigWithTerminal(terminal, client) diff --git a/cmd/root_test.go b/cmd/root_test.go new file mode 100644 index 00000000..d05b1cdd --- /dev/null +++ b/cmd/root_test.go @@ -0,0 +1,26 @@ +package cmd + +import ( + "bytes" + "regexp" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRoot(t *testing.T) { + t.Run("usage info describes sub-commands", func(t *testing.T) { + buf := new(bytes.Buffer) + rootCmd := NewRootCommand() + rootCmd.SetOut(buf) + + err := rootCmd.Help() + + require.NoError(t, err) + output := buf.String() + require.Regexp(t, regexp.MustCompile(`Usage:\n\s+gh models \[command\]`), output) + require.Regexp(t, regexp.MustCompile(`list\s+List available models`), output) + require.Regexp(t, regexp.MustCompile(`run\s+Run inference with the specified model`), output) + require.Regexp(t, regexp.MustCompile(`view\s+View details about a model`), output) + }) +} diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index 92755745..a9f8da17 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -3,6 +3,7 @@ package run import ( "bytes" "context" + "regexp" "testing" "github.com/github/gh-models/internal/azuremodels" @@ -59,4 +60,24 @@ func TestRun(t *testing.T) { output := buf.String() require.Contains(t, output, fakeMessageFromModel) }) + + t.Run("--help prints usage info", func(t *testing.T) { + outBuf := new(bytes.Buffer) + errBuf := new(bytes.Buffer) + runCmd := NewRunCommand(nil) + runCmd.SetOut(outBuf) + runCmd.SetErr(errBuf) + runCmd.SetArgs([]string{"--help"}) + + err := runCmd.Help() + + require.NoError(t, err) + output := outBuf.String() + require.Contains(t, output, "Run inference with the specified model") + require.Regexp(t, regexp.MustCompile(`--max-tokens string\s+Limit the maximum tokens for the model response\.`), output) + require.Regexp(t, regexp.MustCompile(`--system-prompt string\s+Prompt the system\.`), output) + require.Regexp(t, regexp.MustCompile(`--temperature string\s+Controls randomness in the response, use lower to be more deterministic\.`), output) + require.Regexp(t, regexp.MustCompile(`--top-p string\s+Controls text diversity by selecting the most probable words until a set probability is reached\.`), output) + require.Empty(t, errBuf.String()) + }) } diff --git a/cmd/view/view_test.go b/cmd/view/view_test.go index 8348b31e..537c968b 100644 --- a/cmd/view/view_test.go +++ b/cmd/view/view_test.go @@ -86,4 +86,19 @@ func TestView(t *testing.T) { require.Contains(t, output, "Evaluation:") require.Contains(t, output, modelDetails.Evaluation) }) + + t.Run("--help prints usage info", func(t *testing.T) { + outBuf := new(bytes.Buffer) + errBuf := new(bytes.Buffer) + viewCmd := NewViewCommand(nil) + viewCmd.SetOut(outBuf) + viewCmd.SetErr(errBuf) + viewCmd.SetArgs([]string{"--help"}) + + err := viewCmd.Help() + + require.NoError(t, err) + require.Contains(t, outBuf.String(), "View details about a model") + require.Empty(t, errBuf.String()) + }) } diff --git a/internal/azuremodels/azure_client.go b/internal/azuremodels/azure_client.go index 14dcd2e1..4014a1aa 100644 --- a/internal/azuremodels/azure_client.go +++ b/internal/azuremodels/azure_client.go @@ -21,21 +21,22 @@ import ( type AzureClient struct { client *http.Client token string + cfg *AzureClientConfig } -const ( - prodInferenceURL = "https://models.inference.ai.azure.com/chat/completions" - azureAiStudioURL = "https://api.catalog.azureml.ms" - prodModelsURL = azureAiStudioURL + "/asset-gallery/v1.0/models" -) - -// NewAzureClient returns a new Azure client using the given auth token. -func NewAzureClient(authToken string) *AzureClient { - httpClient, _ := api.DefaultHTTPClient() - return &AzureClient{ - client: httpClient, - token: authToken, +// NewDefaultAzureClient returns a new Azure client using the given auth token using default API URLs. +func NewDefaultAzureClient(authToken string) (*AzureClient, error) { + httpClient, err := api.DefaultHTTPClient() + if err != nil { + return nil, err } + cfg := NewDefaultAzureClientConfig() + return &AzureClient{client: httpClient, token: authToken, cfg: cfg}, nil +} + +// NewAzureClient returns a new Azure client using the given HTTP client, configuration, and auth token. +func NewAzureClient(httpClient *http.Client, authToken string, cfg *AzureClientConfig) *AzureClient { + return &AzureClient{client: httpClient, token: authToken, cfg: cfg} } // GetChatCompletionStream returns a stream of chat completions using the given options. @@ -54,7 +55,7 @@ func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompl body := bytes.NewReader(bodyBytes) - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodInferenceURL, body) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.InferenceURL, body) if err != nil { return nil, err } @@ -99,7 +100,7 @@ func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompl // GetModelDetails returns the details of the specified model in a particular registry. func (c *AzureClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { - url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version) + url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", c.cfg.AzureAiStudioURL, registry, modelName, version) httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { return nil, err @@ -189,7 +190,7 @@ func (c *AzureClient) ListModels(ctx context.Context) ([]*ModelSummary, error) { } `)) - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodModelsURL, body) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.ModelsURL, body) if err != nil { return nil, err } diff --git a/internal/azuremodels/azure_client_config.go b/internal/azuremodels/azure_client_config.go new file mode 100644 index 00000000..7675aa63 --- /dev/null +++ b/internal/azuremodels/azure_client_config.go @@ -0,0 +1,23 @@ +package azuremodels + +const ( + defaultInferenceURL = "https://models.inference.ai.azure.com/chat/completions" + defaultAzureAiStudioURL = "https://api.catalog.azureml.ms" + defaultModelsURL = defaultAzureAiStudioURL + "/asset-gallery/v1.0/models" +) + +// AzureClientConfig represents configurable settings for the Azure client. +type AzureClientConfig struct { + InferenceURL string + AzureAiStudioURL string + ModelsURL string +} + +// NewDefaultAzureClientConfig returns a new AzureClientConfig with default values for API URLs. +func NewDefaultAzureClientConfig() *AzureClientConfig { + return &AzureClientConfig{ + InferenceURL: defaultInferenceURL, + AzureAiStudioURL: defaultAzureAiStudioURL, + ModelsURL: defaultModelsURL, + } +} diff --git a/internal/azuremodels/azure_client_test.go b/internal/azuremodels/azure_client_test.go new file mode 100644 index 00000000..bd69c43b --- /dev/null +++ b/internal/azuremodels/azure_client_test.go @@ -0,0 +1,68 @@ +package azuremodels + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAzureClient(t *testing.T) { + ctx := context.Background() + + t.Run("GetModelDetails happy path", func(t *testing.T) { + registry := "foo" + modelName := "bar" + version := "baz" + textLimits := &modelCatalogTextLimits{MaxOutputTokens: 8675309, InputContextWindow: 3} + modelLimits := &modelCatalogLimits{ + SupportedInputModalities: []string{"books", "VHS"}, + SupportedOutputModalities: []string{"watercolor paintings"}, + SupportedLanguages: []string{"fr", "zh"}, + TextLimits: textLimits, + } + playgroundLimits := &modelCatalogPlaygroundLimits{RateLimitTier: "big-ish"} + catalogDetails := &modelCatalogDetailsResponse{ + Description: "some model description", + License: "My Favorite License", + LicenseDescription: "This is a test license", + Notes: "You aren't gonna believe these notes.", + Keywords: []string{"Tag1", "TAG2"}, + Evaluation: "This model is the best.", + ModelLimits: modelLimits, + PlaygroundLimits: playgroundLimits, + } + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + require.Equal(t, "/asset-gallery/v1.0/"+registry+"/models/"+modelName+"/version/"+version, r.URL.Path) + + w.WriteHeader(http.StatusOK) + err := json.NewEncoder(w).Encode(catalogDetails) + require.NoError(t, err) + })) + defer testServer.Close() + cfg := &AzureClientConfig{AzureAiStudioURL: testServer.URL} + httpClient := testServer.Client() + client := NewAzureClient(httpClient, "fake-token-123abc", cfg) + + details, err := client.GetModelDetails(ctx, registry, modelName, version) + + require.NoError(t, err) + require.NotNil(t, details) + require.Equal(t, catalogDetails.Description, details.Description) + require.Equal(t, catalogDetails.License, details.License) + require.Equal(t, catalogDetails.LicenseDescription, details.LicenseDescription) + require.Equal(t, catalogDetails.Notes, details.Notes) + require.Equal(t, []string{"tag1", "tag2"}, details.Tags) + require.Equal(t, catalogDetails.Evaluation, details.Evaluation) + require.Equal(t, modelLimits.SupportedInputModalities, details.SupportedInputModalities) + require.Equal(t, modelLimits.SupportedOutputModalities, details.SupportedOutputModalities) + require.Equal(t, []string{"French", "Chinese"}, details.SupportedLanguages) + require.Equal(t, textLimits.MaxOutputTokens, details.MaxOutputTokens) + require.Equal(t, textLimits.InputContextWindow, details.MaxInputTokens) + require.Equal(t, playgroundLimits.RateLimitTier, details.RateLimitTier) + }) +} diff --git a/internal/azuremodels/types.go b/internal/azuremodels/types.go index d8d5a52d..2cd4d2af 100644 --- a/internal/azuremodels/types.go +++ b/internal/azuremodels/types.go @@ -98,37 +98,43 @@ func (m *ModelSummary) HasName(name string) bool { return strings.EqualFold(m.FriendlyName, name) || strings.EqualFold(m.Name, name) } +type modelCatalogTextLimits struct { + MaxOutputTokens int `json:"maxOutputTokens"` + InputContextWindow int `json:"inputContextWindow"` +} + +type modelCatalogLimits struct { + SupportedLanguages []string `json:"supportedLanguages"` + TextLimits *modelCatalogTextLimits `json:"textLimits"` + SupportedInputModalities []string `json:"supportedInputModalities"` + SupportedOutputModalities []string `json:"supportedOutputModalities"` +} + +type modelCatalogPlaygroundLimits struct { + RateLimitTier string `json:"rateLimitTier"` +} + type modelCatalogDetailsResponse struct { - AssetID string `json:"assetId"` - Name string `json:"name"` - DisplayName string `json:"displayName"` - Publisher string `json:"publisher"` - Version string `json:"version"` - RegistryName string `json:"registryName"` - Evaluation string `json:"evaluation"` - Summary string `json:"summary"` - Description string `json:"description"` - License string `json:"license"` - LicenseDescription string `json:"licenseDescription"` - Notes string `json:"notes"` - Keywords []string `json:"keywords"` - InferenceTasks []string `json:"inferenceTasks"` - FineTuningTasks []string `json:"fineTuningTasks"` - Labels []string `json:"labels"` - TradeRestricted bool `json:"tradeRestricted"` - CreatedTime string `json:"createdTime"` - PlaygroundLimits *struct { - RateLimitTier string `json:"rateLimitTier"` - } `json:"playgroundLimits"` - ModelLimits *struct { - SupportedLanguages []string `json:"supportedLanguages"` - TextLimits *struct { - MaxOutputTokens int `json:"maxOutputTokens"` - InputContextWindow int `json:"inputContextWindow"` - } `json:"textLimits"` - SupportedInputModalities []string `json:"supportedInputModalities"` - SupportedOutputModalities []string `json:"supportedOutputModalities"` - } `json:"modelLimits"` + AssetID string `json:"assetId"` + Name string `json:"name"` + DisplayName string `json:"displayName"` + Publisher string `json:"publisher"` + Version string `json:"version"` + RegistryName string `json:"registryName"` + Evaluation string `json:"evaluation"` + Summary string `json:"summary"` + Description string `json:"description"` + License string `json:"license"` + LicenseDescription string `json:"licenseDescription"` + Notes string `json:"notes"` + Keywords []string `json:"keywords"` + InferenceTasks []string `json:"inferenceTasks"` + FineTuningTasks []string `json:"fineTuningTasks"` + Labels []string `json:"labels"` + TradeRestricted bool `json:"tradeRestricted"` + CreatedTime string `json:"createdTime"` + PlaygroundLimits *modelCatalogPlaygroundLimits `json:"playgroundLimits"` + ModelLimits *modelCatalogLimits `json:"modelLimits"` } // ModelDetails includes detailed information about a model. diff --git a/internal/ux/sorting_test.go b/internal/ux/sorting_test.go new file mode 100644 index 00000000..35169446 --- /dev/null +++ b/internal/ux/sorting_test.go @@ -0,0 +1,24 @@ +package ux + +import ( + "testing" + + "github.com/github/gh-models/internal/azuremodels" + "github.com/stretchr/testify/require" +) + +func TestSorting(t *testing.T) { + t.Run("SortModels sorts given slice in-place by friendly name, case-insensitive", func(t *testing.T) { + modelA := &azuremodels.ModelSummary{Name: "z", FriendlyName: "AARDVARK"} + modelB := &azuremodels.ModelSummary{Name: "y", FriendlyName: "betta"} + modelC := &azuremodels.ModelSummary{Name: "x", FriendlyName: "Cat"} + models := []*azuremodels.ModelSummary{modelB, modelA, modelC} + + SortModels(models) + + require.Equal(t, 3, len(models)) + require.Equal(t, "AARDVARK", models[0].FriendlyName) + require.Equal(t, "betta", models[1].FriendlyName) + require.Equal(t, "Cat", models[2].FriendlyName) + }) +}