Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feature, go): Expose extra response properties #3669

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 1 addition & 1 deletion generators/go/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.21.3
0.22.0
31 changes: 2 additions & 29 deletions generators/go/internal/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,8 @@ func (g *Generator) generate(ir *fernir.IntermediateRepresentation, mode Mode) (
files = append(files, modelFiles...)
files = append(files, newStringerFile(g.coordinator))
files = append(files, newTimeFile(g.coordinator))
if needsExtraPropertyHelpers(ir) {
files = append(files, newExtraPropertiesFile(g.coordinator))
files = append(files, newExtraPropertiesTestFile(g.coordinator))
}
files = append(files, newExtraPropertiesFile(g.coordinator))
files = append(files, newExtraPropertiesTestFile(g.coordinator))
// Then handle mode-specific generation tasks.
var rootClientInstantiation *ast.AssignStmt
generatedRootClient := &GeneratedClient{
Expand Down Expand Up @@ -1511,31 +1509,6 @@ func needsPaginationHelpers(ir *fernir.IntermediateRepresentation) bool {
return false
}

// needsExtraPropertyHelpers returns true if at least one object or in-lined request supports
// extra properties, or any unions are specified with samePropertiesAsObject.
func needsExtraPropertyHelpers(ir *fernir.IntermediateRepresentation) bool {
for _, irType := range ir.Types {
if irType.Shape.Object != nil && irType.Shape.Object.ExtraProperties {
return true
}
if irType.Shape.Union != nil {
for _, unionType := range irType.Shape.Union.Types {
if unionType.Shape.SamePropertiesAsObject != nil {
return true
}
}
}
}
for _, irService := range ir.Services {
for _, irEndpoint := range irService.Endpoints {
if irEndpoint.RequestBody != nil && irEndpoint.RequestBody.InlinedRequestBody != nil && irEndpoint.RequestBody.InlinedRequestBody.ExtraProperties {
return true
}
}
}
return false
}

// pointerFunctionNames enumerates all of the pointer function names.
var pointerFunctionNames = map[string]struct{}{
"Bool": struct{}{},
Expand Down
113 changes: 65 additions & 48 deletions generators/go/internal/generator/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,16 @@ func (t *typeVisitor) VisitObject(object *ir.ObjectTypeDeclaration) error {
for _, literal := range objectProperties.literals {
t.writer.P(literal.Name.Name.CamelCase.SafeName, " ", literalToGoType(literal.Value))
}
extraPropertiesFieldName := getExtraPropertiesFieldName(object.ExtraProperties)
if object.ExtraProperties {
t.writer.P()
t.writer.P("ExtraProperties map[string]interface{} `json:\"-\" url:\"-\"`")
t.writer.P(extraPropertiesFieldName, " map[string]interface{} `json:\"-\" url:\"-\"`")
t.writer.P()
} else {
t.writer.P()
t.writer.P(extraPropertiesFieldName, " map[string]interface{}")
}
if t.includeRawJSON {
t.writer.P()
t.writer.P("_rawJSON json.RawMessage")
}
t.writer.P("}")
Expand All @@ -164,6 +168,10 @@ func (t *typeVisitor) VisitObject(object *ir.ObjectTypeDeclaration) error {
receiver := typeNameToReceiver(t.typeName)

// Implement the getter methods.
t.writer.P("func (", receiver, " *", t.typeName, ") GetExtraProperties() map[string]interface{} {")
t.writer.P("return ", receiver, ".", extraPropertiesFieldName)
t.writer.P("}")
t.writer.P()
for _, literal := range objectProperties.literals {
t.writer.P("func (", receiver, " *", t.typeName, ") ", literal.Name.Name.PascalCase.UnsafeName, "()", literalToGoType(literal.Value), "{")
t.writer.P("return ", receiver, ".", literal.Name.Name.CamelCase.SafeName)
Expand All @@ -172,54 +180,54 @@ func (t *typeVisitor) VisitObject(object *ir.ObjectTypeDeclaration) error {
}

// Implement the json.Unmarshaler interface.
if t.includeRawJSON || len(objectProperties.literals) > 0 || len(objectProperties.dates) > 0 || object.ExtraProperties {
if t.includeRawJSON && len(objectProperties.literals) == 0 && len(objectProperties.dates) == 0 && !object.ExtraProperties {
// If we don't require any special unmarshaling, prefer the simpler implementation.
t.writer.P("func (", receiver, " *", t.typeName, ") UnmarshalJSON(data []byte) error {")
t.writer.P("type unmarshaler ", t.typeName)
t.writer.P("var value unmarshaler")
t.writer.P("if err := json.Unmarshal(data, &value); err != nil {")
t.writer.P("return err")
t.writer.P("}")
t.writer.P("*", receiver, " = ", t.typeName, "(value)")
if len(objectProperties.literals) == 0 && len(objectProperties.dates) == 0 && !object.ExtraProperties {
// If we don't require any special unmarshaling, prefer the simpler implementation.
t.writer.P("func (", receiver, " *", t.typeName, ") UnmarshalJSON(data []byte) error {")
t.writer.P("type unmarshaler ", t.typeName)
t.writer.P("var value unmarshaler")
t.writer.P("if err := json.Unmarshal(data, &value); err != nil {")
t.writer.P("return err")
t.writer.P("}")
t.writer.P("*", receiver, " = ", t.typeName, "(value)")
t.writer.P()
writeExtractExtraProperties(t.writer, objectProperties.literals, receiver, extraPropertiesFieldName)
if t.includeRawJSON {
t.writer.P(receiver, "._rawJSON = json.RawMessage(data)")
t.writer.P("return nil")
t.writer.P("}")
t.writer.P()
} else {
t.writer.P("func (", receiver, " *", t.typeName, ") UnmarshalJSON(data []byte) error {")
t.writer.P("type embed ", t.typeName)
t.writer.P("var unmarshaler = struct{")
t.writer.P("embed")
for _, date := range objectProperties.dates {
t.writer.P(date.Name.Name.PascalCase.UnsafeName, " ", date.TypeDeclaration, " ", date.StructTag)
}
t.writer.P("}{")
t.writer.P("embed: embed(*", receiver, "),")
t.writer.P("}")
t.writer.P("if err := json.Unmarshal(data, &unmarshaler); err != nil {")
t.writer.P("return err")
t.writer.P("}")
t.writer.P("*", receiver, " = ", t.typeName, "(unmarshaler.embed)")
for _, date := range objectProperties.dates {
t.writer.P(receiver, ".", date.Name.Name.PascalCase.UnsafeName, " = unmarshaler.", date.Name.Name.PascalCase.UnsafeName, ".", date.TimeMethod)
}
for _, literal := range objectProperties.literals {
t.writer.P(receiver, ".", literal.Name.Name.CamelCase.SafeName, " = ", literalToValue(literal.Value))
}
if object.ExtraProperties {
t.writer.P()
writeExtractExtraProperties(t.writer, objectProperties.literals, receiver)
}
if t.includeRawJSON {
t.writer.P()
t.writer.P(receiver, "._rawJSON = json.RawMessage(data)")
}

t.writer.P("return nil")
t.writer.P("}")
}
t.writer.P("return nil")
t.writer.P("}")
t.writer.P()
} else {
t.writer.P("func (", receiver, " *", t.typeName, ") UnmarshalJSON(data []byte) error {")
t.writer.P("type embed ", t.typeName)
t.writer.P("var unmarshaler = struct{")
t.writer.P("embed")
for _, date := range objectProperties.dates {
t.writer.P(date.Name.Name.PascalCase.UnsafeName, " ", date.TypeDeclaration, " ", date.StructTag)
}
t.writer.P("}{")
t.writer.P("embed: embed(*", receiver, "),")
t.writer.P("}")
t.writer.P("if err := json.Unmarshal(data, &unmarshaler); err != nil {")
t.writer.P("return err")
t.writer.P("}")
t.writer.P("*", receiver, " = ", t.typeName, "(unmarshaler.embed)")
for _, date := range objectProperties.dates {
t.writer.P(receiver, ".", date.Name.Name.PascalCase.UnsafeName, " = unmarshaler.", date.Name.Name.PascalCase.UnsafeName, ".", date.TimeMethod)
}
for _, literal := range objectProperties.literals {
t.writer.P(receiver, ".", literal.Name.Name.CamelCase.SafeName, " = ", literalToValue(literal.Value))
}
t.writer.P()
writeExtractExtraProperties(t.writer, objectProperties.literals, receiver, extraPropertiesFieldName)
if t.includeRawJSON {
t.writer.P()
t.writer.P(receiver, "._rawJSON = json.RawMessage(data)")
}

t.writer.P("return nil")
t.writer.P("}")
t.writer.P()
}

// Implement the json.Marshaler interface.
Expand Down Expand Up @@ -1298,6 +1306,7 @@ func writeExtractExtraProperties(
f *fileWriter,
literals []*literal,
receiver string,
extraPropertiesFieldName string,
) {
var exclude string
if len(literals) > 0 {
Expand All @@ -1310,7 +1319,7 @@ func writeExtractExtraProperties(
f.P("if err != nil {")
f.P("return err")
f.P("}")
f.P(receiver, ".ExtraProperties = extraProperties")
f.P(receiver, ".", extraPropertiesFieldName, " = extraProperties")
f.P()
}

Expand Down Expand Up @@ -1495,6 +1504,14 @@ func tagFormatForType(
return `%s:"%s,omitempty"`
}

func getExtraPropertiesFieldName(extraPropertiesEnabled bool) string {
if extraPropertiesEnabled {
return "ExtraProperties"
}
return "extraProperties"

}

// unknownToGoType maps the given unknown into its Go-equivalent.
func unknownToGoType(_ any) string {
return "interface{}"
Expand Down
6 changes: 3 additions & 3 deletions generators/go/internal/generator/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -1568,8 +1568,8 @@ func newGeneratedEndpoint(

func endpointToIdentifier(endpoint *ir.HttpEndpoint) *generatorexec.EndpointIdentifier {
return &generatorexec.EndpointIdentifier{
Path: fullPathForEndpoint(endpoint),
Method: irMethodToGeneratorExecMethod(endpoint.Method),
Path: fullPathForEndpoint(endpoint),
Method: irMethodToGeneratorExecMethod(endpoint.Method),
IdentifierOverride: &endpoint.Id,
}
}
Expand Down Expand Up @@ -2535,7 +2535,7 @@ func (f *fileWriter) WriteRequestType(
}
if requestBody.extraProperties {
f.P()
writeExtractExtraProperties(f, literals, receiver)
writeExtractExtraProperties(f, literals, receiver, getExtraPropertiesFieldName(requestBody.extraProperties))
}
f.P("return nil")
f.P("}")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package core

import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"strings"
)

// MarshalJSONWithExtraProperty marshals the given value to JSON, including the extra property.
func MarshalJSONWithExtraProperty(marshaler interface{}, key string, value interface{}) ([]byte, error) {
return MarshalJSONWithExtraProperties(marshaler, map[string]interface{}{key: value})
}

// MarshalJSONWithExtraProperties marshals the given value to JSON, including any extra properties.
func MarshalJSONWithExtraProperties(marshaler interface{}, extraProperties map[string]interface{}) ([]byte, error) {
bytes, err := json.Marshal(marshaler)
if err != nil {
return nil, err
}
if len(extraProperties) == 0 {
return bytes, nil
}
keys, err := getKeys(marshaler)
if err != nil {
return nil, err
}
for _, key := range keys {
if _, ok := extraProperties[key]; ok {
return nil, fmt.Errorf("cannot add extra property %q because it is already defined on the type", key)
}
}
extraBytes, err := json.Marshal(extraProperties)
if err != nil {
return nil, err
}
if isEmptyJSON(bytes) {
if isEmptyJSON(extraBytes) {
return bytes, nil
}
return extraBytes, nil
}
result := bytes[:len(bytes)-1]
result = append(result, ',')
result = append(result, extraBytes[1:len(extraBytes)-1]...)
result = append(result, '}')
return result, nil
}

// ExtractExtraProperties extracts any extra properties from the given value.
func ExtractExtraProperties(bytes []byte, value interface{}, exclude ...string) (map[string]interface{}, error) {
val := reflect.ValueOf(value)
for val.Kind() == reflect.Ptr {
if val.IsNil() {
return nil, fmt.Errorf("value must be non-nil to extract extra properties")
}
val = val.Elem()
}
if err := json.Unmarshal(bytes, &value); err != nil {
return nil, err
}
var extraProperties map[string]interface{}
if err := json.Unmarshal(bytes, &extraProperties); err != nil {
return nil, err
}
for i := 0; i < val.Type().NumField(); i++ {
key := jsonKey(val.Type().Field(i))
if key == "" || key == "-" {
continue
}
delete(extraProperties, key)
}
for _, key := range exclude {
delete(extraProperties, key)
}
if len(extraProperties) == 0 {
return nil, nil
}
return extraProperties, nil
}

// getKeys returns the keys associated with the given value. The value must be a
// a struct or a map with string keys.
func getKeys(value interface{}) ([]string, error) {
val := reflect.ValueOf(value)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
if !val.IsValid() {
return nil, nil
}
switch val.Kind() {
case reflect.Struct:
return getKeysForStructType(val.Type()), nil
case reflect.Map:
var keys []string
if val.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("cannot extract keys from %T; only structs and maps with string keys are supported", value)
}
for _, key := range val.MapKeys() {
keys = append(keys, key.String())
}
return keys, nil
default:
return nil, fmt.Errorf("cannot extract keys from %T; only structs and maps with string keys are supported", value)
}
}

// getKeysForStructType returns all the keys associated with the given struct type,
// visiting embedded fields recursively.
func getKeysForStructType(structType reflect.Type) []string {
if structType.Kind() == reflect.Pointer {
structType = structType.Elem()
}
if structType.Kind() != reflect.Struct {
return nil
}
var keys []string
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
if field.Anonymous {
keys = append(keys, getKeysForStructType(field.Type)...)
continue
}
keys = append(keys, jsonKey(field))
}
return keys
}

// jsonKey returns the JSON key from the struct tag of the given field,
// excluding the omitempty flag (if any).
func jsonKey(field reflect.StructField) string {
return strings.TrimSuffix(field.Tag.Get("json"), ",omitempty")
}

// isEmptyJSON returns true if the given data is empty, the empty JSON object, or
// an explicit null.
func isEmptyJSON(data []byte) bool {
return len(data) <= 2 || bytes.Equal(data, []byte("null"))
}
Loading
Loading