Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for dynamic message + wkt #12

Merged
merged 5 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ linters-settings:
keywords: [FIXME]
varnamelen:
ignore-decls:
- ok bool
- T any
- i int
- wg sync.WaitGroup
Expand Down
198 changes: 157 additions & 41 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,7 @@ func (u *unmarshaler) checkTag(node *yaml.Node, expected string) {
}
}

func (u *unmarshaler) findAnyType(node *yaml.Node) protoreflect.MessageType {
if len(node.Content) == 0 {
return nil
}
func (u *unmarshaler) findAnyTypeURL(node *yaml.Node) string {
typeURL := ""
for i := 1; i < len(node.Content); i += 2 {
keyNode := node.Content[i-1]
Expand All @@ -161,10 +158,10 @@ func (u *unmarshaler) findAnyType(node *yaml.Node) protoreflect.MessageType {
break
}
}
if typeURL == "" {
return nil
}
return typeURL
}

func (u *unmarshaler) resolveAnyType(typeURL string) (protoreflect.MessageType, error) {
// Get the message type.
var msgType protoreflect.MessageType
var err error
Expand All @@ -174,10 +171,17 @@ func (u *unmarshaler) findAnyType(node *yaml.Node) protoreflect.MessageType {
msgType, err = protoregistry.GlobalTypes.FindMessageByURL(typeURL)
}
if err != nil {
u.addErrorf(node, "unknown type %q: %v", typeURL, err)
return nil
return nil, err
}
return msgType
return msgType, nil
}

func (u *unmarshaler) findAnyType(node *yaml.Node) (protoreflect.MessageType, error) {
typeURL := u.findAnyTypeURL(node)
if typeURL == "" {
return nil, errors.New("missing @type field")
}
return u.resolveAnyType(typeURL)
}

func (u *unmarshaler) findType(msgDesc protoreflect.MessageDescriptor) protoreflect.MessageType {
Expand Down Expand Up @@ -572,17 +576,20 @@ func (u *unmarshaler) unmarshalMap(node *yaml.Node, field protoreflect.FieldDesc
}
}

func isNull(node *yaml.Node) bool {
return node.Tag == "!!null"
}

// Unmarshal the given yaml node into the given proto.Message.
func (u *unmarshaler) unmarshalMessage(node *yaml.Node, message proto.Message) {
if node.Tag == "!!null" {
return // Null is always allowed for messages
}

// Check for a custom unmarshaler
custom, ok := u.custom[message.ProtoReflect().Descriptor().FullName()]
if ok && custom(u, node, message) {
return // Custom unmarshaler handled the decoding
}
if isNull(node) {
return // Null is always allowed for messages
}
if node.Kind != yaml.MappingNode {
u.addErrorf(node, "expected fields for %v, got %v",
message.ProtoReflect().Descriptor().FullName(), getNodeKind(node.Kind))
Expand Down Expand Up @@ -629,20 +636,30 @@ func addWktUnmarshalers(custom map[protoreflect.FullName]customUnmarshaler) {
}

func unmarshalAnyMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
anyVal, ok := message.(*anypb.Any)
if !ok || !unm.checkKind(node, yaml.MappingNode) || len(node.Content) == 0 {
if node.Kind != yaml.MappingNode || len(node.Content) == 0 {
return false
}
anyVal, ok := message.(*anypb.Any)
if !ok {
anyVal = &anypb.Any{}
}

// Get the message type
msgType := unm.findAnyType(node)
if msgType != nil {
protoVal := msgType.New()
unm.unmarshalMessage(node, protoVal.Interface())
err := anyVal.MarshalFrom(protoVal.Interface())
if err != nil {
unm.addErrorf(node, "failed to marshal %v: %v", msgType.Descriptor().FullName(), err)
}
// Get the message type.
msgType, err := unm.findAnyType(node)
if err != nil {
unm.addError(node, err)
return true
}

protoVal := msgType.New()
unm.unmarshalMessage(node, protoVal.Interface())
if err = anyVal.MarshalFrom(protoVal.Interface()); err != nil {
unm.addErrorf(node, "failed to marshal %v: %v", msgType.Descriptor().FullName(), err)
}

if !ok {
return setFieldByName(message, "type_url", protoreflect.ValueOfString(anyVal.TypeUrl)) &&
setFieldByName(message, "value", protoreflect.ValueOfBytes(anyVal.Value))
}

return true
Expand Down Expand Up @@ -687,7 +704,7 @@ func parseDuration(txt string, duration *durationpb.Duration) error {
if power < 0 {
return errors.New("too many fractional second digits")
}
nanos *= 10 ^ int64(power)
nanos *= int64(math.Pow10(power))
duration.Nanos = int32(nanos)
default:
return errors.New("invalid duration: too many '.' characters")
Expand Down Expand Up @@ -721,33 +738,57 @@ func parseTimestamp(txt string, timestamp *timestamppb.Timestamp) error {
return nil
}

func unmarshalDurationMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
duration, ok := message.(*durationpb.Duration)
if node.Kind != yaml.ScalarNode || len(node.Value) == 0 || !ok {
func setFieldByName(message proto.Message, name string, value protoreflect.Value) bool {
field := message.ProtoReflect().Descriptor().Fields().ByName(protoreflect.Name(name))
if field == nil {
return false
}
message.ProtoReflect().Set(field, value)
return true
}

func unmarshalDurationMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
if node.Kind != yaml.ScalarNode || len(node.Value) == 0 || isNull(node) {
return false
}
duration, ok := message.(*durationpb.Duration)
if !ok {
duration = &durationpb.Duration{}
}
err := parseDuration(node.Value, duration)
if err != nil {
unm.addErrorf(node, "invalid duration: %v", err)
} else if !ok {
// Set the fields dynamically.
return setFieldByName(message, "seconds", protoreflect.ValueOfInt64(duration.Seconds)) &&
setFieldByName(message, "nanos", protoreflect.ValueOfInt32(duration.Nanos))
}
return true
}

func unmarshalTimestampMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
timestamp, ok := message.(*timestamppb.Timestamp)
if node.Kind != yaml.ScalarNode || len(node.Value) == 0 || !ok {
if node.Kind != yaml.ScalarNode || len(node.Value) == 0 || isNull(node) {
return false
}
timestamp, ok := message.(*timestamppb.Timestamp)
if !ok {
timestamp = &timestamppb.Timestamp{}
}
err := parseTimestamp(node.Value, timestamp)
if err != nil {
unm.addErrorf(node, "invalid timestamp: %v", err)
} else if !ok {
return setFieldByName(message, "seconds", protoreflect.ValueOfInt64(timestamp.Seconds)) &&
setFieldByName(message, "nanos", protoreflect.ValueOfInt32(timestamp.Nanos))
}
return true
}

// Forwards unmarshaling to the "value" field of the given wrapper message.
func unmarshalWrapperMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
if isNull(node) {
return true
}
valueField := message.ProtoReflect().Descriptor().Fields().ByName("value")
if node.Kind == yaml.MappingNode || valueField == nil {
return false
Expand All @@ -756,30 +797,105 @@ func unmarshalWrapperMsg(unm *unmarshaler, node *yaml.Node, message proto.Messag
return true
}

func unmarshalValueMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
if value, ok := message.(*structpb.Value); ok {
unmarshalValue(unm, node, value, nil, false)
return true
func dynSetValue(message proto.Message, value *structpb.Value) bool {
switch val := value.Kind.(type) {
case *structpb.Value_NullValue:
return setFieldByName(message, "null_value", protoreflect.ValueOfEnum(protoreflect.EnumNumber(val.NullValue)))
case *structpb.Value_NumberValue:
return setFieldByName(message, "number_value", protoreflect.ValueOfFloat64(val.NumberValue))
case *structpb.Value_StringValue:
return setFieldByName(message, "string_value", protoreflect.ValueOfString(val.StringValue))
case *structpb.Value_BoolValue:
return setFieldByName(message, "bool_value", protoreflect.ValueOfBool(val.BoolValue))
case *structpb.Value_ListValue:
listFld := message.ProtoReflect().Descriptor().Fields().ByName("list_value")
if listFld == nil {
return false
}
listVal := message.ProtoReflect().Mutable(listFld).Message().Interface()
return dynSetListValue(listVal, val.ListValue)
case *structpb.Value_StructValue:
structFld := message.ProtoReflect().Descriptor().Fields().ByName("struct_value")
if structFld == nil {
return false
}
structVal := message.ProtoReflect().Mutable(structFld).Message().Interface()
return dynSetStruct(structVal, val.StructValue)
}
return false
}

func unmarshalListValueMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
listValue, ok := message.(*structpb.ListValue)
if !ok || node.Kind != yaml.SequenceNode {
func dynSetListValue(message proto.Message, list *structpb.ListValue) bool {
valuesFld := message.ProtoReflect().Descriptor().Fields().ByName("values")
if valuesFld == nil {
return false
}
values := message.ProtoReflect().Mutable(valuesFld).List()
for _, item := range list.Values {
value := values.NewElement()
if !dynSetValue(value.Message().Interface(), item) {
return false
}
values.Append(value)
}
return true
}

func dynSetStruct(message proto.Message, structVal *structpb.Struct) bool {
fieldsFld := message.ProtoReflect().Descriptor().Fields().ByName("fields")
if fieldsFld == nil {
return false
}
fields := message.ProtoReflect().Mutable(fieldsFld).Map()
for key, item := range structVal.Fields {
value := fields.NewValue()
if !dynSetValue(value.Message().Interface(), item) {
return false
}
fields.Set(protoreflect.ValueOfString(key).MapKey(), value)
}
return true
}

func unmarshalValueMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
value, ok := message.(*structpb.Value)
if !ok {
value = &structpb.Value{}
}
unmarshalValue(unm, node, value, nil, false)
if !ok {
return dynSetValue(message, value)
}
return true
}

func unmarshalListValueMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
if node.Kind != yaml.SequenceNode {
return false
}
listValue, ok := message.(*structpb.ListValue)
if !ok {
listValue = &structpb.ListValue{}
}
unmarshalListValue(unm, node, listValue, nil)
if !ok {
return dynSetListValue(message, listValue)
}
return true
}

func unmarshalStructMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
structVal, ok := message.(*structpb.Struct)
if !ok || node.Kind != yaml.MappingNode {
if node.Kind != yaml.MappingNode {
return false
}
structVal, ok := message.(*structpb.Struct)
if !ok {
structVal = &structpb.Struct{}
}
unmarshalStruct(unm, node, structVal, nil, nil)
if !ok {
return dynSetStruct(message, structVal)
}
return true
}

Expand Down Expand Up @@ -860,7 +976,7 @@ func unmarshalStruct(
}
} else if msgDesc == nil {
// Try to find the message descriptor.
msgType := unm.findAnyType(node)
msgType, _ := unm.findAnyType(node)
if msgType != nil {
msgDesc = msgType.Descriptor()
}
Expand Down
Loading
Loading