From f94b7a75ba7c1cef53da38587b77e838248431c7 Mon Sep 17 00:00:00 2001 From: Stainless Bot Date: Thu, 1 Aug 2024 11:05:04 +0000 Subject: [PATCH] chore(internal): updates --- README.md | 59 ++------- azure/azure.go | 237 ---------------------------------- azure/azure_test.go | 130 ------------------- azure/example_auth_test.go | 47 ------- chatcompletion.go | 92 ------------- go.mod | 24 +--- go.sum | 31 +---- internal/apijson/decoder.go | 121 +++++++++-------- internal/apijson/json_test.go | 53 +++++++- 9 files changed, 135 insertions(+), 659 deletions(-) delete mode 100644 azure/azure.go delete mode 100644 azure/azure_test.go delete mode 100644 azure/example_auth_test.go diff --git a/README.md b/README.md index 0397a061..87464506 100644 --- a/README.md +++ b/README.md @@ -51,9 +51,10 @@ func main() { option.WithAPIKey("My API Key"), // defaults to os.LookupEnv("OPENAI_API_KEY") ) chatCompletion, err := client.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("Say this is a test"), - }), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{ + Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), + Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("Say this is a test")), + }}), Model: openai.F(openai.ChatModelGPT4o), }) if err != nil { @@ -236,9 +237,10 @@ defer cancel() client.Chat.Completions.New( ctx, openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("Say this is a test"), - }), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{ + Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), + Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("How can I list all files in a directory using Python?")), + }}), Model: openai.F(openai.ChatModelGPT4o), }, // This sets the per-retry timeout @@ -298,9 +300,10 @@ client := openai.NewClient( client.Chat.Completions.New( context.TODO(), openai.ChatCompletionNewParams{ - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("Say this is a test"), - }), + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{ + Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), + Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("How can I get the name of the current day in Node.js?")), + }}), Model: openai.F(openai.ChatModelGPT4o), }, option.WithMaxRetries(5), @@ -393,44 +396,6 @@ You may also replace the default `http.Client` with accepted (this overwrites any previous client) and receives requests after any middleware has been applied. -## Microsoft Azure OpenAI - -To use this library with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/overview), use the option.RequestOption functions in the `azure` package. - -```go -package main - -import ( - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/openai/openai-go" - "github.com/openai/openai-go/azure" - "github.com/openai/openai-go/option" -) - -func main() { - const azureOpenAIEndpoint = "https://.openai.azure.com" - - // The latest API versions, including previews, can be found here: - // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning - const azureOpenAIAPIVersion = "2024-06-01" - - tokenCredential, err := azidentity.NewDefaultAzureCredential(nil) - - if err != nil { - fmt.Printf("Failed to create the DefaultAzureCredential: %s", err) - os.Exit(1) - } - - client := openai.NewClient( - azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), - - // Choose between authenticating using a TokenCredential or an API Key - azure.WithTokenCredential(tokenCredential), - // or azure.WithAPIKey(azureOpenAIAPIKey), - ) -} -``` - ## Semantic versioning This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions: diff --git a/azure/azure.go b/azure/azure.go deleted file mode 100644 index 5d3156fc..00000000 --- a/azure/azure.go +++ /dev/null @@ -1,237 +0,0 @@ -// Package azure provides configuration options so you can connect and use Azure OpenAI using the [openai.Client]. -// -// Typical usage of this package will look like this: -// -// client := openai.NewClient( -// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), -// azure.WithTokenCredential(azureIdentityTokenCredential), -// // or azure.WithAPIKey(azureOpenAIAPIKey), -// ) -// -// Or, if you want to construct a specific service: -// -// client := openai.NewChatCompletionService( -// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), -// azure.WithTokenCredential(azureIdentityTokenCredential), -// // or azure.WithAPIKey(azureOpenAIAPIKey), -// ) -package azure - -import ( - "bytes" - "encoding/json" - "errors" - "io" - "mime" - "mime/multipart" - "net/http" - "net/url" - "strings" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" - "github.com/openai/openai-go/internal/requestconfig" - "github.com/openai/openai-go/option" -) - -// WithEndpoint configures this client to connect to an Azure OpenAI endpoint. -// -// - endpoint - the Azure OpenAI endpoint to connect to. Ex: https://.openai.azure.com -// - apiVersion - the Azure OpenAI API version to target (ex: 2024-06-01). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty. -// -// This function should be paired with a call to authenticate, like [azure.WithAPIKey] or [azure.WithTokenCredential], similar to this: -// -// client := openai.NewClient( -// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), -// azure.WithTokenCredential(azureIdentityTokenCredential), -// // or azure.WithAPIKey(azureOpenAIAPIKey), -// ) -// -// [Azure OpenAI apiversions]: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning -func WithEndpoint(endpoint string, apiVersion string) option.RequestOption { - if !strings.HasSuffix(endpoint, "/") { - endpoint += "/" - } - - endpoint += "openai/" - - withQueryAdd := option.WithQueryAdd("api-version", apiVersion) - withEndpoint := option.WithBaseURL(endpoint) - - withModelMiddleware := option.WithMiddleware(func(r *http.Request, mn option.MiddlewareNext) (*http.Response, error) { - replacementPath, err := getReplacementPathWithDeployment(r) - - if err != nil { - return nil, err - } - - r.URL.Path = replacementPath - return mn(r) - }) - - return func(rc *requestconfig.RequestConfig) error { - if apiVersion == "" { - return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details.") - } - - if err := withQueryAdd(rc); err != nil { - return err - } - - if err := withEndpoint(rc); err != nil { - return err - } - - if err := withModelMiddleware(rc); err != nil { - return err - } - - return nil - } -} - -// WithTokenCredential configures this client to authenticate using an [Azure Identity] TokenCredential. -// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance. -// -// [Azure Identity]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity -func WithTokenCredential(tokenCredential azcore.TokenCredential) option.RequestOption { - bearerTokenPolicy := runtime.NewBearerTokenPolicy(tokenCredential, []string{"https://cognitiveservices.azure.com/.default"}, nil) - - // add in a middleware that uses the bearer token generated from the token credential - return option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { - pipeline := runtime.NewPipeline("azopenai-extensions", version, runtime.PipelineOptions{}, &policy.ClientOptions{ - InsecureAllowCredentialWithHTTP: true, // allow for plain HTTP proxies, etc.. - PerRetryPolicies: []policy.Policy{ - bearerTokenPolicy, - policyAdapter(next), - }, - }) - - req2, err := runtime.NewRequestFromRequest(req) - - if err != nil { - return nil, err - } - - return pipeline.Do(req2) - }) -} - -// WithAPIKey configures this client to authenticate using an API key. -// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance. -func WithAPIKey(apiKey string) option.RequestOption { - // NOTE: there is an option.WithApiKey(), but that adds the value into - // the Authorization header instead so we're doing this instead. - return option.WithHeader("Api-Key", apiKey) -} - -// jsonRoutes have JSON payloads - we'll deserialize looking for a .model field in there -// so we won't have to worry about individual types for completions vs embeddings, etc... -var jsonRoutes = map[string]bool{ - "/openai/completions": true, - "/openai/chat/completions": true, - "/openai/embeddings": true, - "/openai/audio/speech": true, - "/openai/images/generations": true, -} - -// audioMultipartRoutes have mime/multipart payloads. These are less generic - we're very much -// expecting a transcription or translation payload for these. -var audioMultipartRoutes = map[string]bool{ - "/openai/audio/transcriptions": true, - "/openai/audio/translations": true, -} - -// getReplacementPathWithDeployment parses the request body to extract out the Model parameter (or equivalent) -// (note, the req.Body is fully read as part of this, and is replaced with a bytes.Reader) -func getReplacementPathWithDeployment(req *http.Request) (string, error) { - if jsonRoutes[req.URL.Path] { - return getJSONRoute(req) - } - - if audioMultipartRoutes[req.URL.Path] { - return getAudioMultipartRoute(req) - } - - // No need to relocate the path. We've already tacked on /openai when we setup the endpoint. - return req.URL.Path, nil -} - -func getJSONRoute(req *http.Request) (string, error) { - // we need to deserialize the body, partly, in order to read out the model field. - jsonBytes, err := io.ReadAll(req.Body) - - if err != nil { - return "", err - } - - // make sure we restore the body so it can be used in later middlewares. - req.Body = io.NopCloser(bytes.NewReader(jsonBytes)) - - var v *struct { - Model string `json:"model"` - } - - if err := json.Unmarshal(jsonBytes, &v); err != nil { - return "", err - } - - escapedDeployment := url.PathEscape(v.Model) - return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil -} - -func getAudioMultipartRoute(req *http.Request) (string, error) { - // body is a multipart/mime body type instead. - mimeBytes, err := io.ReadAll(req.Body) - - if err != nil { - return "", err - } - - // make sure we restore the body so it can be used in later middlewares. - req.Body = io.NopCloser(bytes.NewReader(mimeBytes)) - - _, mimeParams, err := mime.ParseMediaType(req.Header.Get("Content-Type")) - - if err != nil { - return "", err - } - - mimeReader := multipart.NewReader( - io.NopCloser(bytes.NewReader(mimeBytes)), - mimeParams["boundary"]) - - for { - mp, err := mimeReader.NextPart() - - if err != nil { - if errors.Is(err, io.EOF) { - return "", errors.New("unable to find the model part in multipart body") - } - - return "", err - } - - defer mp.Close() - - if mp.FormName() == "model" { - modelBytes, err := io.ReadAll(mp) - - if err != nil { - return "", err - } - - escapedDeployment := url.PathEscape(string(modelBytes)) - return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil - } - } -} - -type policyAdapter option.MiddlewareNext - -func (mp policyAdapter) Do(req *policy.Request) (*http.Response, error) { - return (option.MiddlewareNext)(mp)(req.Raw()) -} - -const version = "v.0.1.0" diff --git a/azure/azure_test.go b/azure/azure_test.go deleted file mode 100644 index 00f57331..00000000 --- a/azure/azure_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package azure - -import ( - "bytes" - "mime/multipart" - "net/http" - "testing" - - "github.com/openai/openai-go" - "github.com/openai/openai-go/internal/apijson" - "github.com/openai/openai-go/shared" -) - -func TestJSONRoute(t *testing.T) { - chatCompletionParams := openai.ChatCompletionNewParams{ - Model: openai.F(openai.ChatModel("arbitraryDeployment")), - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ - openai.ChatCompletionAssistantMessageParam{ - Role: openai.F(openai.ChatCompletionAssistantMessageParamRoleAssistant), - Content: openai.String("You are a helpful assistant"), - }, - openai.ChatCompletionUserMessageParam{ - Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), - Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("Can you tell me another word for the universe?")), - }, - }), - } - - serializedBytes, err := apijson.MarshalRoot(chatCompletionParams) - - if err != nil { - t.Fatal(err) - } - - req, err := http.NewRequest("POST", "/openai/chat/completions", bytes.NewReader(serializedBytes)) - - if err != nil { - t.Fatal(err) - } - - replacementPath, err := getReplacementPathWithDeployment(req) - - if err != nil { - t.Fatal(err) - } - - if replacementPath != "/openai/deployments/arbitraryDeployment/chat/completions" { - t.Fatalf("replacementpath didn't match: %s", replacementPath) - } -} - -func TestGetAudioMultipartRoute(t *testing.T) { - buff := &bytes.Buffer{} - mw := multipart.NewWriter(buff) - defer mw.Close() - - fw, err := mw.CreateFormFile("file", "test.mp3") - - if err != nil { - t.Fatal(err) - } - - if _, err = fw.Write([]byte("ignore me")); err != nil { - t.Fatal(err) - } - - if err := mw.WriteField("model", "arbitraryDeployment"); err != nil { - t.Fatal(err) - } - - if err := mw.Close(); err != nil { - t.Fatal(err) - } - - req, err := http.NewRequest("POST", "/openai/audio/transcriptions", bytes.NewReader(buff.Bytes())) - - if err != nil { - t.Fatal(err) - } - - req.Header.Set("Content-Type", mw.FormDataContentType()) - - replacementPath, err := getReplacementPathWithDeployment(req) - - if err != nil { - t.Fatal(err) - } - - if replacementPath != "/openai/deployments/arbitraryDeployment/audio/transcriptions" { - t.Fatalf("replacementpath didn't match: %s", replacementPath) - } -} - -func TestNoRouteChangeNeeded(t *testing.T) { - chatCompletionParams := openai.ChatCompletionNewParams{ - Model: openai.F(openai.ChatModel("arbitraryDeployment")), - Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ - openai.ChatCompletionAssistantMessageParam{ - Role: openai.F(openai.ChatCompletionAssistantMessageParamRoleAssistant), - Content: openai.String("You are a helpful assistant"), - }, - openai.ChatCompletionUserMessageParam{ - Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), - Content: openai.F[openai.ChatCompletionUserMessageParamContentUnion](shared.UnionString("Can you tell me another word for the universe?")), - }, - }), - } - - serializedBytes, err := apijson.MarshalRoot(chatCompletionParams) - - if err != nil { - t.Fatal(err) - } - - req, err := http.NewRequest("POST", "/openai/does/not/need/a/deployment", bytes.NewReader(serializedBytes)) - - if err != nil { - t.Fatal(err) - } - - replacementPath, err := getReplacementPathWithDeployment(req) - - if err != nil { - t.Fatal(err) - } - - if replacementPath != "/openai/does/not/need/a/deployment" { - t.Fatalf("replacementpath didn't match: %s", replacementPath) - } -} diff --git a/azure/example_auth_test.go b/azure/example_auth_test.go deleted file mode 100644 index 3a8ef214..00000000 --- a/azure/example_auth_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package azure_test - -import ( - "fmt" - - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/openai/openai-go" - "github.com/openai/openai-go/azure" -) - -func Example_authentication() { - // There are two ways to authenticate - using a TokenCredential (via the azidentity - // package), or using an API Key. - const azureOpenAIEndpoint = "https://.openai.azure.com" - const azureOpenAIAPIVersion = "" - - // Using a TokenCredential - { - // For a full list of credential types look at the documentation for the Azure Identity - // package: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity - tokenCredential, err := azidentity.NewDefaultAzureCredential(nil) - - if err != nil { - fmt.Printf("Failed to create TokenCredential: %s\n", err) - return - } - - client := openai.NewClient( - azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), - azure.WithTokenCredential(tokenCredential), - ) - - _ = client - } - - // Using an API Key - { - const azureOpenAIAPIKey = "" - - client := openai.NewClient( - azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), - azure.WithAPIKey(azureOpenAIAPIKey), - ) - - _ = client - } -} diff --git a/chatcompletion.go b/chatcompletion.go index f6331cbf..4dcd393e 100644 --- a/chatcompletion.go +++ b/chatcompletion.go @@ -12,73 +12,8 @@ import ( "github.com/openai/openai-go/option" "github.com/openai/openai-go/packages/ssestream" "github.com/openai/openai-go/shared" - "github.com/tidwall/sjson" ) -func UserMessage(content string) ChatCompletionMessageParamUnion { - return ChatCompletionUserMessageParam{ - Role: F(ChatCompletionUserMessageParamRoleUser), - Content: F[ChatCompletionUserMessageParamContentUnion]( - shared.UnionString(content), - ), - } -} - -func UserMessageBlocks(blocks ...ChatCompletionContentPartUnionParam) ChatCompletionMessageParamUnion { - return ChatCompletionUserMessageParam{ - Role: F(ChatCompletionUserMessageParamRoleUser), - Content: F[ChatCompletionUserMessageParamContentUnion]( - ChatCompletionUserMessageParamContentArrayOfContentParts(blocks), - ), - } -} - -func UserMessageTextBlock(content string) ChatCompletionContentPartUnionParam { - return ChatCompletionContentPartTextParam{ - Type: F(ChatCompletionContentPartTextTypeText), - Text: F(content), - } -} - -func UserMessageImageBlock(url string) ChatCompletionContentPartUnionParam { - return ChatCompletionContentPartImageParam{ - Type: F(ChatCompletionContentPartImageTypeImageURL), - ImageURL: F(ChatCompletionContentPartImageImageURLParam{ - URL: F(url), - }), - } -} - -func AssistantMessage(content string) ChatCompletionMessageParamUnion { - return ChatCompletionAssistantMessageParam{ - Role: F(ChatCompletionAssistantMessageParamRoleAssistant), - Content: F(content), - } -} - -func ToolMessage(toolCallID, content string) ChatCompletionMessageParamUnion { - return ChatCompletionToolMessageParam{ - Role: F(ChatCompletionToolMessageParamRoleTool), - ToolCallID: F(toolCallID), - Content: F(content), - } -} - -func SystemMessage(content string) ChatCompletionMessageParamUnion { - return ChatCompletionSystemMessageParam{ - Role: F(ChatCompletionSystemMessageParamRoleSystem), - Content: F(content), - } -} - -func FunctionMessage(name, content string) ChatCompletionMessageParamUnion { - return ChatCompletionFunctionMessageParam{ - Role: F(ChatCompletionFunctionMessageParamRoleFunction), - Name: F(name), - Content: F(content), - } -} - // ChatCompletionService contains methods and other services that help with // interacting with the openai API. // @@ -850,35 +785,10 @@ func (r *ChatCompletionMessage) UnmarshalJSON(data []byte) (err error) { return apijson.UnmarshalRoot(data, r) } -func (r ChatCompletionMessage) MarshalJSON() (data []byte, err error) { - s := "" - s, _ = sjson.Set(s, "role", r.Role) - - if r.FunctionCall.Name != "" { - b, err := apijson.Marshal(r.FunctionCall) - if err != nil { - return nil, err - } - s, _ = sjson.SetRaw(s, "function_call", string(b)) - } else if len(r.ToolCalls) > 0 { - b, err := apijson.Marshal(r.ToolCalls) - if err != nil { - return nil, err - } - s, _ = sjson.SetRaw(s, "tool_calls", string(b)) - } else { - s, _ = sjson.Set(s, "content", r.Content) - } - - return []byte(s), nil -} - func (r chatCompletionMessageJSON) RawJSON() string { return r.raw } -func (r ChatCompletionMessage) implementsChatCompletionMessageParamUnion() {} - // The role of the author of this message. type ChatCompletionMessageRole string @@ -947,8 +857,6 @@ func (r ChatCompletionMessageParam) implementsChatCompletionMessageParamUnion() // [ChatCompletionUserMessageParam], [ChatCompletionAssistantMessageParam], // [ChatCompletionToolMessageParam], [ChatCompletionFunctionMessageParam], // [ChatCompletionMessageParam]. -// -// This union is additionally satisfied by the return types [ChatCompletionMessage] type ChatCompletionMessageParamUnion interface { implementsChatCompletionMessageParamUnion() } diff --git a/go.mod b/go.mod index a487ea3d..1e064e63 100644 --- a/go.mod +++ b/go.mod @@ -3,27 +3,9 @@ module github.com/openai/openai-go go 1.19 require ( - github.com/tidwall/gjson v1.14.4 - github.com/tidwall/sjson v1.2.5 -) - -require ( - github.com/google/uuid v1.6.0 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect -) - -require ( - // NOTE: these dependencies are only used for the `azure` subpackage. - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0 - github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 - github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect - github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect - github.com/kylelemons/godebug v1.1.0 // indirect - github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect - golang.org/x/crypto v0.25.0 // indirect - golang.org/x/net v0.27.0 // indirect - golang.org/x/sys v0.22.0 // indirect - golang.org/x/text v0.16.0 // indirect + github.com/tidwall/sjson v1.2.5 // indirect ) diff --git a/go.sum b/go.sum index 240415f5..569e555a 100644 --- a/go.sum +++ b/go.sum @@ -1,22 +1,5 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0 h1:GJHeeA2N7xrG3q30L2UXDyuWRzDM900/65j70wcM4Ww= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= -github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU= -github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= -github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= -github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -27,13 +10,3 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= -golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/apijson/decoder.go b/internal/apijson/decoder.go index deb0bac6..e1b21b7a 100644 --- a/internal/apijson/decoder.go +++ b/internal/apijson/decoder.go @@ -214,15 +214,29 @@ func (d *decoderBuilder) newUnionDecoder(t reflect.Type) decoderFunc { decoders = append(decoders, decoder) } return func(n gjson.Result, v reflect.Value, state *decoderState) error { - // Set bestExactness to worse than loose - bestExactness := loose - 1 - + // If there is a discriminator match, circumvent the exactness logic entirely for idx, variant := range unionEntry.variants { decoder := decoders[idx] if variant.TypeFilter != n.Type { continue } - if len(unionEntry.discriminatorKey) != 0 && n.Get(unionEntry.discriminatorKey).Value() != variant.DiscriminatorValue { + + if len(unionEntry.discriminatorKey) != 0 { + discriminatorValue := n.Get(unionEntry.discriminatorKey).Value() + if discriminatorValue == variant.DiscriminatorValue { + inner := reflect.New(variant.Type).Elem() + err := decoder(n, inner, state) + v.Set(inner) + return err + } + } + } + + // Set bestExactness to worse than loose + bestExactness := loose - 1 + for idx, variant := range unionEntry.variants { + decoder := decoders[idx] + if variant.TypeFilter != n.Type { continue } sub := decoderState{strict: state.strict, exactness: exact} @@ -325,62 +339,58 @@ func (d *decoderBuilder) newArrayTypeDecoder(t reflect.Type) decoderFunc { func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc { // map of json field name to struct field decoders decoderFields := map[string]decoderField{} + anonymousDecoders := []decoderField{} extraDecoder := (*decoderField)(nil) inlineDecoder := (*decoderField)(nil) - // This helper allows us to recursively collect field encoders into a flat - // array. The parameter `index` keeps track of the access patterns necessary - // to get to some field. - var collectFieldDecoders func(r reflect.Type, index []int) - collectFieldDecoders = func(r reflect.Type, index []int) { - for i := 0; i < r.NumField(); i++ { - idx := append(index, i) - field := t.FieldByIndex(idx) - if !field.IsExported() { - continue - } - // If this is an embedded struct, traverse one level deeper to extract - // the fields and get their encoders as well. - if field.Anonymous { - collectFieldDecoders(field.Type, idx) - continue - } - // If json tag is not present, then we skip, which is intentionally - // different behavior from the stdlib. - ptag, ok := parseJSONStructTag(field) - if !ok { - continue - } - // We only want to support unexported fields if they're tagged with - // `extras` because that field shouldn't be part of the public API. We - // also want to only keep the top level extras - if ptag.extras && len(index) == 0 { - extraDecoder = &decoderField{ptag, d.typeDecoder(field.Type.Elem()), idx, field.Name} - continue - } - if ptag.inline && len(index) == 0 { - inlineDecoder = &decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} - continue - } - if ptag.metadata { - continue - } + for i := 0; i < t.NumField(); i++ { + idx := []int{i} + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the fields and get their encoders as well. + if field.Anonymous { + anonymousDecoders = append(anonymousDecoders, decoderField{ + fn: d.typeDecoder(field.Type), + idx: idx[:], + }) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + // We only want to support unexported fields if they're tagged with + // `extras` because that field shouldn't be part of the public API. + if ptag.extras { + extraDecoder = &decoderField{ptag, d.typeDecoder(field.Type.Elem()), idx, field.Name} + continue + } + if ptag.inline { + inlineDecoder = &decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} + continue + } + if ptag.metadata { + continue + } - oldFormat := d.dateFormat - dateFormat, ok := parseFormatStructTag(field) - if ok { - switch dateFormat { - case "date-time": - d.dateFormat = time.RFC3339 - case "date": - d.dateFormat = "2006-01-02" - } + oldFormat := d.dateFormat + dateFormat, ok := parseFormatStructTag(field) + if ok { + switch dateFormat { + case "date-time": + d.dateFormat = time.RFC3339 + case "date": + d.dateFormat = "2006-01-02" } - decoderFields[ptag.name] = decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} - d.dateFormat = oldFormat } + decoderFields[ptag.name] = decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} + d.dateFormat = oldFormat } - collectFieldDecoders(t, []int{}) return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { if field := value.FieldByName("JSON"); field.IsValid() { @@ -389,6 +399,11 @@ func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc { } } + for _, decoder := range anonymousDecoders { + // ignore errors + decoder.fn(node, value.FieldByIndex(decoder.idx), state) + } + if inlineDecoder != nil { var meta Field dest := value.FieldByIndex(inlineDecoder.idx) diff --git a/internal/apijson/json_test.go b/internal/apijson/json_test.go index 43cea307..72bc4c29 100644 --- a/internal/apijson/json_test.go +++ b/internal/apijson/json_test.go @@ -48,10 +48,32 @@ type TypedAdditionalProperties struct { ExtraFields map[string]int `json:"-,extras"` } +type EmbeddedStruct struct { + A bool `json:"a"` + B string `json:"b"` + + JSON EmbeddedStructJSON +} + +type EmbeddedStructJSON struct { + A Field + B Field + ExtraFields map[string]Field + raw string +} + type EmbeddedStructs struct { - AdditionalProperties - A *int `json:"number2"` + EmbeddedStruct + A *int `json:"a"` ExtraFields map[string]interface{} `json:"-,extras"` + + JSON EmbeddedStructsJSON +} + +type EmbeddedStructsJSON struct { + A Field + ExtraFields map[string]Field + raw string } type Recursive struct { @@ -332,9 +354,34 @@ var tests = map[string]struct { }, }, + "embedded_struct": { + `{"a":1,"b":"bar"}`, + EmbeddedStructs{ + EmbeddedStruct: EmbeddedStruct{ + A: true, + B: "bar", + JSON: EmbeddedStructJSON{ + A: Field{raw: `1`, status: valid}, + B: Field{raw: `"bar"`, status: valid}, + raw: `{"a":1,"b":"bar"}`, + }, + }, + A: P(1), + ExtraFields: map[string]interface{}{"b": "bar"}, + JSON: EmbeddedStructsJSON{ + A: Field{raw: `1`, status: valid}, + ExtraFields: map[string]Field{ + "b": {raw: `"bar"`, status: valid}, + }, + raw: `{"a":1,"b":"bar"}`, + }, + }, + }, + "recursive_struct": { `{"child":{"name":"Alex"},"name":"Robert"}`, - Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}}, }, + Recursive{Name: "Robert", Child: &Recursive{Name: "Alex"}}, + }, "metadata_coerce": { `{"a":"12","b":"12","c":null,"extra_typed":12,"extra_untyped":{"foo":"bar"}}`,