-
Notifications
You must be signed in to change notification settings - Fork 16
Add more test coverage, stop swallowing client error #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a71b2a7
b9acd54
2574762
e8f27c8
5788fd5
d1d835b
0e0d495
61eea86
13cfac3
ef298fc
ece6efe
17ae765
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,7 +28,7 @@ jobs: | |
| - uses: actions/checkout@v4 | ||
| - uses: actions/setup-go@v5 | ||
| with: | ||
| go-version: ${{ vars.GOVERSION }} | ||
| go-version: ">=1.22" | ||
| check-latest: true | ||
| - name: Verify go.sum is up to date | ||
| run: | | ||
|
|
@@ -43,4 +43,6 @@ jobs: | |
| run: go build -v ./... | ||
|
|
||
| - name: Run tests | ||
| run: go test -race -cover ./... | ||
| run: | | ||
| go version | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Outputting the version for visibility when looking at test results, in case the version affects the results. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice improvement! We could include this in a |
||
| go test -race -cover ./... | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,7 +32,12 @@ func NewRootCommand() *cobra.Command { | |
| util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") | ||
| client = azuremodels.NewUnauthenticatedClient() | ||
| } else { | ||
| client = azuremodels.NewAzureClient(token) | ||
| var err error | ||
| client, err = azuremodels.NewDefaultAzureClient(token) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Putting together an |
||
| if err != nil { | ||
| util.WriteToOut(terminal.ErrOut(), "Error creating Azure client: "+err.Error()) | ||
| return nil | ||
| } | ||
| } | ||
|
|
||
| cfg := command.NewConfigWithTerminal(terminal, client) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| package cmd | ||
|
|
||
| import ( | ||
| "bytes" | ||
| "regexp" | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| func TestRoot(t *testing.T) { | ||
| t.Run("usage info describes sub-commands", func(t *testing.T) { | ||
| buf := new(bytes.Buffer) | ||
| rootCmd := NewRootCommand() | ||
| rootCmd.SetOut(buf) | ||
|
|
||
| err := rootCmd.Help() | ||
|
|
||
| require.NoError(t, err) | ||
| output := buf.String() | ||
| require.Regexp(t, regexp.MustCompile(`Usage:\n\s+gh models \[command\]`), output) | ||
| require.Regexp(t, regexp.MustCompile(`list\s+List available models`), output) | ||
| require.Regexp(t, regexp.MustCompile(`run\s+Run inference with the specified model`), output) | ||
| require.Regexp(t, regexp.MustCompile(`view\s+View details about a model`), output) | ||
| }) | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,21 +21,22 @@ import ( | |
| type AzureClient struct { | ||
| client *http.Client | ||
| token string | ||
| cfg *AzureClientConfig | ||
| } | ||
|
|
||
| const ( | ||
| prodInferenceURL = "https://models.inference.ai.azure.com/chat/completions" | ||
| azureAiStudioURL = "https://api.catalog.azureml.ms" | ||
| prodModelsURL = azureAiStudioURL + "/asset-gallery/v1.0/models" | ||
| ) | ||
|
|
||
| // NewAzureClient returns a new Azure client using the given auth token. | ||
| func NewAzureClient(authToken string) *AzureClient { | ||
| httpClient, _ := api.DefaultHTTPClient() | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the error we were swallowing before. |
||
| return &AzureClient{ | ||
| client: httpClient, | ||
| token: authToken, | ||
| // NewDefaultAzureClient returns a new Azure client using the given auth token using default API URLs. | ||
| func NewDefaultAzureClient(authToken string) (*AzureClient, error) { | ||
| httpClient, err := api.DefaultHTTPClient() | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| cfg := NewDefaultAzureClientConfig() | ||
| return &AzureClient{client: httpClient, token: authToken, cfg: cfg}, nil | ||
| } | ||
|
|
||
| // NewAzureClient returns a new Azure client using the given HTTP client, configuration, and auth token. | ||
| func NewAzureClient(httpClient *http.Client, authToken string, cfg *AzureClientConfig) *AzureClient { | ||
| return &AzureClient{client: httpClient, token: authToken, cfg: cfg} | ||
| } | ||
|
|
||
| // GetChatCompletionStream returns a stream of chat completions using the given options. | ||
|
|
@@ -54,7 +55,7 @@ func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompl | |
|
|
||
| body := bytes.NewReader(bodyBytes) | ||
|
|
||
| httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodInferenceURL, body) | ||
| httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.InferenceURL, body) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
@@ -99,7 +100,7 @@ func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompl | |
|
|
||
| // GetModelDetails returns the details of the specified model in a particular registry. | ||
| func (c *AzureClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { | ||
| url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", azureAiStudioURL, registry, modelName, version) | ||
| url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", c.cfg.AzureAiStudioURL, registry, modelName, version) | ||
| httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) | ||
| if err != nil { | ||
| return nil, err | ||
|
|
@@ -189,7 +190,7 @@ func (c *AzureClient) ListModels(ctx context.Context) ([]*ModelSummary, error) { | |
| } | ||
| `)) | ||
|
|
||
| httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, prodModelsURL, body) | ||
| httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.ModelsURL, body) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| package azuremodels | ||
|
|
||
| const ( | ||
| defaultInferenceURL = "https://models.inference.ai.azure.com/chat/completions" | ||
| defaultAzureAiStudioURL = "https://api.catalog.azureml.ms" | ||
| defaultModelsURL = defaultAzureAiStudioURL + "/asset-gallery/v1.0/models" | ||
| ) | ||
|
|
||
| // AzureClientConfig represents configurable settings for the Azure client. | ||
| type AzureClientConfig struct { | ||
| InferenceURL string | ||
| AzureAiStudioURL string | ||
| ModelsURL string | ||
| } | ||
|
|
||
| // NewDefaultAzureClientConfig returns a new AzureClientConfig with default values for API URLs. | ||
| func NewDefaultAzureClientConfig() *AzureClientConfig { | ||
| return &AzureClientConfig{ | ||
| InferenceURL: defaultInferenceURL, | ||
| AzureAiStudioURL: defaultAzureAiStudioURL, | ||
| ModelsURL: defaultModelsURL, | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| package azuremodels | ||
|
|
||
| import ( | ||
| "context" | ||
| "encoding/json" | ||
| "net/http" | ||
| "net/http/httptest" | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| func TestAzureClient(t *testing.T) { | ||
| ctx := context.Background() | ||
|
|
||
| t.Run("GetModelDetails happy path", func(t *testing.T) { | ||
| registry := "foo" | ||
| modelName := "bar" | ||
| version := "baz" | ||
| textLimits := &modelCatalogTextLimits{MaxOutputTokens: 8675309, InputContextWindow: 3} | ||
| modelLimits := &modelCatalogLimits{ | ||
| SupportedInputModalities: []string{"books", "VHS"}, | ||
| SupportedOutputModalities: []string{"watercolor paintings"}, | ||
| SupportedLanguages: []string{"fr", "zh"}, | ||
| TextLimits: textLimits, | ||
| } | ||
| playgroundLimits := &modelCatalogPlaygroundLimits{RateLimitTier: "big-ish"} | ||
| catalogDetails := &modelCatalogDetailsResponse{ | ||
| Description: "some model description", | ||
| License: "My Favorite License", | ||
| LicenseDescription: "This is a test license", | ||
| Notes: "You aren't gonna believe these notes.", | ||
| Keywords: []string{"Tag1", "TAG2"}, | ||
| Evaluation: "This model is the best.", | ||
| ModelLimits: modelLimits, | ||
| PlaygroundLimits: playgroundLimits, | ||
| } | ||
| testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| require.Equal(t, "application/json", r.Header.Get("Content-Type")) | ||
| require.Equal(t, "/asset-gallery/v1.0/"+registry+"/models/"+modelName+"/version/"+version, r.URL.Path) | ||
|
|
||
| w.WriteHeader(http.StatusOK) | ||
| err := json.NewEncoder(w).Encode(catalogDetails) | ||
| require.NoError(t, err) | ||
| })) | ||
| defer testServer.Close() | ||
| cfg := &AzureClientConfig{AzureAiStudioURL: testServer.URL} | ||
| httpClient := testServer.Client() | ||
| client := NewAzureClient(httpClient, "fake-token-123abc", cfg) | ||
|
|
||
| details, err := client.GetModelDetails(ctx, registry, modelName, version) | ||
|
|
||
| require.NoError(t, err) | ||
| require.NotNil(t, details) | ||
| require.Equal(t, catalogDetails.Description, details.Description) | ||
| require.Equal(t, catalogDetails.License, details.License) | ||
| require.Equal(t, catalogDetails.LicenseDescription, details.LicenseDescription) | ||
| require.Equal(t, catalogDetails.Notes, details.Notes) | ||
| require.Equal(t, []string{"tag1", "tag2"}, details.Tags) | ||
| require.Equal(t, catalogDetails.Evaluation, details.Evaluation) | ||
| require.Equal(t, modelLimits.SupportedInputModalities, details.SupportedInputModalities) | ||
| require.Equal(t, modelLimits.SupportedOutputModalities, details.SupportedOutputModalities) | ||
| require.Equal(t, []string{"French", "Chinese"}, details.SupportedLanguages) | ||
| require.Equal(t, textLimits.MaxOutputTokens, details.MaxOutputTokens) | ||
| require.Equal(t, textLimits.InputContextWindow, details.MaxInputTokens) | ||
| require.Equal(t, playgroundLimits.RateLimitTier, details.RateLimitTier) | ||
| }) | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Specifying the same Go version here as we do in our dev docs.