Skip to content
Merged
6 changes: 4 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: ${{ vars.GOVERSION }}
go-version: ">=1.22"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifying the same Go version here as we do in our dev docs.

check-latest: true
- name: Verify go.sum is up to date
run: |
Expand All @@ -43,4 +43,6 @@ jobs:
run: go build -v ./...

- name: Run tests
run: go test -race -cover ./...
run: |
go version
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outputting the version for visibility when looking at test results, in case the version affects the results.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice improvement! We could include this in a script/test as well, for consistency with other repos.

go test -race -cover ./...
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Please note that this project is released with a [Contributor Code of Conduct](C

These are one time installations required to be able to test your changes locally as part of the pull request (PR) submission process.

1. Install Go [through download](https://go.dev/doc/install) | [through Homebrew](https://formulae.brew.sh/formula/go) and ensure it's a least version 1.22
1. Install Go [through download](https://go.dev/doc/install) | [through Homebrew](https://formulae.brew.sh/formula/go) and ensure it's at least version 1.22

## Submitting a pull request

Expand Down
10 changes: 5 additions & 5 deletions DEV.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## Developing
# Developing

### Prerequisites
## Prerequisites

The extension requires the [`gh` CLI](https://cli.github.com/) to be installed and added to the `PATH`. Users must also
authenticate via `gh auth` before using the extension.
Expand All @@ -12,12 +12,12 @@ $ go version
go version go1.22.x <arch>
```

### Building
## Building

To build the project, run `script/build`. After building, you can run the binary locally, for example:
`./gh-models list`.

### Testing
## Testing

To run lint tests, unit tests, and other Go-related checks before submitting a pull request, use:

Expand All @@ -34,7 +34,7 @@ make vet # to find suspicious constructs
make tidy # to keep dependencies up-to-date
```

### Releasing
## Releasing

When upgrading or installing the extension using `gh extension upgrade github/gh-models` or
`gh extension install github/gh-models`, the latest release will be pulled, not the latest commit. Therefore, all
Expand Down
15 changes: 15 additions & 0 deletions cmd/list/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,19 @@ func TestList(t *testing.T) {
require.Contains(t, output, modelSummary.FriendlyName)
require.Contains(t, output, modelSummary.Name)
})

t.Run("--help prints usage info", func(t *testing.T) {
outBuf := new(bytes.Buffer)
errBuf := new(bytes.Buffer)
listCmd := NewListCommand(nil)
listCmd.SetOut(outBuf)
listCmd.SetErr(errBuf)
listCmd.SetArgs([]string{"--help"})

err := listCmd.Help()

require.NoError(t, err)
require.Contains(t, outBuf.String(), "List available models")
require.Empty(t, errBuf.String())
})
}
7 changes: 6 additions & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ func NewRootCommand() *cobra.Command {
util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n")
client = azuremodels.NewUnauthenticatedClient()
} else {
client = azuremodels.NewAzureClient(token)
var err error
client, err = azuremodels.NewDefaultAzureClient(token)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Putting together an AzureClient requires getting an http.Client, which could error. We're no longer silencing that error if it occurs, which helped me figure out a panic occurring on CI when running tests.

if err != nil {
util.WriteToOut(terminal.ErrOut(), "Error creating Azure client: "+err.Error())
return nil
}
}

cfg := command.NewConfigWithTerminal(terminal, client)
Expand Down
26 changes: 26 additions & 0 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package cmd

import (
"bytes"
"regexp"
"testing"

"github.com/stretchr/testify/require"
)

func TestRoot(t *testing.T) {
t.Run("usage info describes sub-commands", func(t *testing.T) {
buf := new(bytes.Buffer)
rootCmd := NewRootCommand()
rootCmd.SetOut(buf)

err := rootCmd.Help()

require.NoError(t, err)
output := buf.String()
require.Regexp(t, regexp.MustCompile(`Usage:\n\s+gh models \[command\]`), output)
require.Regexp(t, regexp.MustCompile(`list\s+List available models`), output)
require.Regexp(t, regexp.MustCompile(`run\s+Run inference with the specified model`), output)
require.Regexp(t, regexp.MustCompile(`view\s+View details about a model`), output)
})
}
21 changes: 21 additions & 0 deletions cmd/run/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package run
import (
"bytes"
"context"
"regexp"
"testing"

"github.com/github/gh-models/internal/azuremodels"
Expand Down Expand Up @@ -59,4 +60,24 @@ func TestRun(t *testing.T) {
output := buf.String()
require.Contains(t, output, fakeMessageFromModel)
})

t.Run("--help prints usage info", func(t *testing.T) {
outBuf := new(bytes.Buffer)
errBuf := new(bytes.Buffer)
runCmd := NewRunCommand(nil)
runCmd.SetOut(outBuf)
runCmd.SetErr(errBuf)
runCmd.SetArgs([]string{"--help"})

err := runCmd.Help()

require.NoError(t, err)
output := outBuf.String()
require.Contains(t, output, "Run inference with the specified model")
require.Regexp(t, regexp.MustCompile(`--max-tokens string\s+Limit the maximum tokens for the model response\.`), output)
require.Regexp(t, regexp.MustCompile(`--system-prompt string\s+Prompt the system\.`), output)
require.Regexp(t, regexp.MustCompile(`--temperature string\s+Controls randomness in the response, use lower to be more deterministic\.`), output)
require.Regexp(t, regexp.MustCompile(`--top-p string\s+Controls text diversity by selecting the most probable words until a set probability is reached\.`), output)
require.Empty(t, errBuf.String())
})
}
15 changes: 15 additions & 0 deletions cmd/view/view_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,19 @@ func TestView(t *testing.T) {
require.Contains(t, output, "Evaluation:")
require.Contains(t, output, modelDetails.Evaluation)
})

t.Run("--help prints usage info", func(t *testing.T) {
outBuf := new(bytes.Buffer)
errBuf := new(bytes.Buffer)
viewCmd := NewViewCommand(nil)
viewCmd.SetOut(outBuf)
viewCmd.SetErr(errBuf)
viewCmd.SetArgs([]string{"--help"})

err := viewCmd.Help()

require.NoError(t, err)
require.Contains(t, outBuf.String(), "View details about a model")
require.Empty(t, errBuf.String())
})
}
31 changes: 16 additions & 15 deletions internal/azuremodels/azure_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,22 @@ import (
type AzureClient struct {
client *http.Client
token string
cfg *AzureClientConfig
}

const (
prodInferenceURL = "https://models.inference.ai.azure.com/chat/completions"
azureAiStudioURL = "https://api.catalog.azureml.ms"
prodModelsURL = azureAiStudioURL + "/asset-gallery/v1.0/models"
)

// NewAzureClient returns a new Azure client using the given auth token.
func NewAzureClient(authToken string) *AzureClient {
httpClient, _ := api.DefaultHTTPClient()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the error we were swallowing before.

return &AzureClient{
client: httpClient,
token: authToken,
// NewDefaultAzureClient returns a new Azure client using the given auth token using default API URLs.
func NewDefaultAzureClient(authToken string) (*AzureClient, error) {
httpClient, err := api.DefaultHTTPClient()
if err != nil {
return nil, err
}
cfg := NewDefaultAzureClientConfig()
return &AzureClient{client: httpClient, token: authToken, cfg: cfg}, nil
}

// NewAzureClient returns a new Azure client using the given HTTP client, configuration, and auth token.
func NewAzureClient(httpClient *http.Client, authToken string, cfg *AzureClientConfig) *AzureClient {
return &AzureClient{client: httpClient, token: authToken, cfg: cfg}
}

// GetChatCompletionStream returns a stream of chat completions using the given options.
Expand All @@ -54,7 +55,7 @@ func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompl

body := bytes.NewReader(bodyBytes)

httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodInferenceURL, body)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.InferenceURL, body)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -99,7 +100,7 @@ func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompl

// GetModelDetails returns the details of the specified model in a particular registry.
func (c *AzureClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) {
url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version)
url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", c.cfg.AzureAiStudioURL, registry, modelName, version)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
return nil, err
Expand Down Expand Up @@ -189,7 +190,7 @@ func (c *AzureClient) ListModels(ctx context.Context) ([]*ModelSummary, error) {
}
`))

httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodModelsURL, body)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.ModelsURL, body)
if err != nil {
return nil, err
}
Expand Down
23 changes: 23 additions & 0 deletions internal/azuremodels/azure_client_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package azuremodels

const (
defaultInferenceURL = "https://models.inference.ai.azure.com/chat/completions"
defaultAzureAiStudioURL = "https://api.catalog.azureml.ms"
defaultModelsURL = defaultAzureAiStudioURL + "/asset-gallery/v1.0/models"
)

// AzureClientConfig represents configurable settings for the Azure client.
type AzureClientConfig struct {
InferenceURL string
AzureAiStudioURL string
ModelsURL string
}

// NewDefaultAzureClientConfig returns a new AzureClientConfig with default values for API URLs.
func NewDefaultAzureClientConfig() *AzureClientConfig {
return &AzureClientConfig{
InferenceURL: defaultInferenceURL,
AzureAiStudioURL: defaultAzureAiStudioURL,
ModelsURL: defaultModelsURL,
}
}
68 changes: 68 additions & 0 deletions internal/azuremodels/azure_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package azuremodels

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/require"
)

func TestAzureClient(t *testing.T) {
ctx := context.Background()

t.Run("GetModelDetails happy path", func(t *testing.T) {
registry := "foo"
modelName := "bar"
version := "baz"
textLimits := &modelCatalogTextLimits{MaxOutputTokens: 8675309, InputContextWindow: 3}
modelLimits := &modelCatalogLimits{
SupportedInputModalities: []string{"books", "VHS"},
SupportedOutputModalities: []string{"watercolor paintings"},
SupportedLanguages: []string{"fr", "zh"},
TextLimits: textLimits,
}
playgroundLimits := &modelCatalogPlaygroundLimits{RateLimitTier: "big-ish"}
catalogDetails := &modelCatalogDetailsResponse{
Description: "some model description",
License: "My Favorite License",
LicenseDescription: "This is a test license",
Notes: "You aren't gonna believe these notes.",
Keywords: []string{"Tag1", "TAG2"},
Evaluation: "This model is the best.",
ModelLimits: modelLimits,
PlaygroundLimits: playgroundLimits,
}
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "application/json", r.Header.Get("Content-Type"))
require.Equal(t, "/asset-gallery/v1.0/"+registry+"/models/"+modelName+"/version/"+version, r.URL.Path)

w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(catalogDetails)
require.NoError(t, err)
}))
defer testServer.Close()
cfg := &AzureClientConfig{AzureAiStudioURL: testServer.URL}
httpClient := testServer.Client()
client := NewAzureClient(httpClient, "fake-token-123abc", cfg)

details, err := client.GetModelDetails(ctx, registry, modelName, version)

require.NoError(t, err)
require.NotNil(t, details)
require.Equal(t, catalogDetails.Description, details.Description)
require.Equal(t, catalogDetails.License, details.License)
require.Equal(t, catalogDetails.LicenseDescription, details.LicenseDescription)
require.Equal(t, catalogDetails.Notes, details.Notes)
require.Equal(t, []string{"tag1", "tag2"}, details.Tags)
require.Equal(t, catalogDetails.Evaluation, details.Evaluation)
require.Equal(t, modelLimits.SupportedInputModalities, details.SupportedInputModalities)
require.Equal(t, modelLimits.SupportedOutputModalities, details.SupportedOutputModalities)
require.Equal(t, []string{"French", "Chinese"}, details.SupportedLanguages)
require.Equal(t, textLimits.MaxOutputTokens, details.MaxOutputTokens)
require.Equal(t, textLimits.InputContextWindow, details.MaxInputTokens)
require.Equal(t, playgroundLimits.RateLimitTier, details.RateLimitTier)
})
}
Loading
Loading