-
-
Notifications
You must be signed in to change notification settings - Fork 4k
/
encode.go
138 lines (131 loc) · 4.22 KB
/
encode.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package binding
import (
"encoding/base64"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"github.com/go-kratos/kratos/v2/encoding/form"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/protobuf/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
// EncodeURL encode proto message to url path.
func EncodeURL(pathTemplate string, msg proto.Message, needQuery bool) string {
if msg == nil || (reflect.ValueOf(msg).Kind() == reflect.Ptr && reflect.ValueOf(msg).IsNil()) {
return pathTemplate
}
reg := regexp.MustCompile(`/{[.\w]+}`)
if reg == nil {
return pathTemplate
}
pathParams := make(map[string]struct{})
path := reg.ReplaceAllStringFunc(pathTemplate, func(in string) string {
if len(in) < 4 {
return in
}
key := in[2 : len(in)-1]
vars := strings.Split(key, ".")
if value, err := getValueByField(msg.ProtoReflect(), vars); err == nil {
pathParams[key] = struct{}{}
return "/" + value
}
return in
})
if needQuery {
u, err := form.EncodeMap(msg)
if err == nil && len(u) > 0 {
for key := range pathParams {
delete(u, key)
}
query := u.Encode()
if query != "" {
path += "?" + query
}
}
}
return path
}
func getValueByField(v protoreflect.Message, fieldPath []string) (string, error) {
var fd protoreflect.FieldDescriptor
for i, fieldName := range fieldPath {
fields := v.Descriptor().Fields()
if fd = fields.ByJSONName(fieldName); fd == nil {
fd = fields.ByName(protoreflect.Name(fieldName))
if fd == nil {
return "", fmt.Errorf("field path not found: %q", fieldName)
}
}
if i == len(fieldPath)-1 {
break
}
if fd.Message() == nil || fd.Cardinality() == protoreflect.Repeated {
return "", fmt.Errorf("invalid path: %q is not a message", fieldName)
}
v = v.Get(fd).Message()
}
return encodeField(fd, v.Get(fd))
}
func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) {
switch fieldDescriptor.Kind() {
case protoreflect.BoolKind:
return strconv.FormatBool(value.Bool()), nil
case protoreflect.EnumKind:
if fieldDescriptor.Enum().FullName() == "google.protobuf.NullValue" {
return "null", nil
}
desc := fieldDescriptor.Enum().Values().ByNumber(value.Enum())
return string(desc.Name()), nil
case protoreflect.StringKind:
return value.String(), nil
case protoreflect.BytesKind:
return base64.URLEncoding.EncodeToString(value.Bytes()), nil
case protoreflect.MessageKind, protoreflect.GroupKind:
return encodeMessage(fieldDescriptor.Message(), value)
default:
return fmt.Sprintf("%v", value.Interface()), nil
}
}
// encodeMessage marshals the fields in the given protoreflect.Message.
// If the typeURL is non-empty, then a synthetic "@type" field is injected
// containing the URL as the value.
func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) {
switch msgDescriptor.FullName() {
case "google.protobuf.Timestamp":
t, ok := value.Interface().(*timestamppb.Timestamp)
if !ok {
return "", nil
}
return t.AsTime().Format(time.RFC3339Nano), nil
case "google.protobuf.Duration":
d, ok := value.Interface().(*durationpb.Duration)
if !ok {
return "", nil
}
return d.AsDuration().String(), nil
case "google.protobuf.BytesValue":
b, ok := value.Interface().(*wrapperspb.BytesValue)
if !ok {
return "", nil
}
return base64.StdEncoding.EncodeToString(b.Value), nil
case "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value",
"google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue":
fd := msgDescriptor.Fields()
v := value.Message().Get(fd.ByName(protoreflect.Name("value"))).Message()
return fmt.Sprintf("%v", v.Interface()), nil
case "google.protobuf.FieldMask":
m, ok := value.Interface().(*field_mask.FieldMask)
if !ok {
return "", nil
}
return strings.Join(m.Paths, ","), nil
default:
return "", fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
}
}