-
Notifications
You must be signed in to change notification settings - Fork 883
/
slice_codec.go
182 lines (151 loc) · 3.83 KB
/
slice_codec.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
package bson
import (
"fmt"
"reflect"
)
// defaultSliceCodec is a package-level *SliceCodec instance.
var defaultSliceCodec = new(SliceCodec)

// SliceCodec is the Codec used for slice and array values.
type SliceCodec struct{}

// Compile-time assertion that *SliceCodec satisfies the Codec interface.
var _ Codec = (*SliceCodec)(nil)
// EncodeValue implements the Codec interface.
//
// i must be a slice or an array; any other kind (including a nil interface)
// yields an error without writing anything. A nil slice is written as a BSON
// null. Any other slice, and every array, is written as a BSON array whose
// elements are encoded in index order by the codec for the static element
// type.
func (sc *SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, i interface{}) error {
	val := reflect.ValueOf(i)
	switch val.Kind() {
	case reflect.Array:
	case reflect.Slice:
		if val.IsNil() { // A nil slice is encoded as BSON null, not an empty array.
			return vw.WriteNull()
		}
	default:
		return fmt.Errorf("%T can only encode arrays and slices", sc)
	}

	// Resolve the element codec once, before anything is written: every
	// element shares the static element type, and doing the lookup first
	// means a lookup failure does not leave vw with an unterminated array.
	// Elements of type tElement (presumably *Element — TODO confirm) go to
	// defaultElementCodec; every other element type is resolved through ec.
	var (
		codec Codec
		err   error
	)
	if elemType := val.Type().Elem(); elemType == tElement {
		codec = defaultElementCodec
	} else if codec, err = ec.Lookup(elemType); err != nil {
		return err
	}

	aw, err := vw.WriteArray()
	if err != nil {
		return err
	}

	length := val.Len()
	for idx := 0; idx < length; idx++ {
		// Use a distinct name so the element writer does not shadow vw.
		evw, werr := aw.WriteArrayElement()
		if werr != nil {
			return werr
		}
		if err = codec.EncodeValue(ec, evw, val.Index(idx).Interface()); err != nil {
			return err
		}
	}
	return aw.WriteArrayEnd()
}
// DecodeValue implements the Codec interface.
//
// i must be a non-nil pointer to a slice or an array; anything else is
// rejected with an error. The BSON value read from vr is handled as follows:
//
//   - null: a slice target is set to its zero value (a nil slice) after the
//     null has been consumed; an array target is an error.
//   - array: every element is decoded first (decodeElement when the element
//     type is tElement, decodeDefault otherwise), and only then is the
//     destination modified. A slice target is replaced by a new slice of
//     exactly the decoded length. An array target is filled in place: if the
//     BSON array has more elements than the Go array an error is returned,
//     and if it has fewer, the remaining Go array elements are reset to their
//     zero value so no stale data from the previous contents survives.
//   - any other BSON type: an error.
func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, i interface{}) error {
	val := reflect.ValueOf(i)
	if !val.IsValid() || val.Kind() != reflect.Ptr || val.IsNil() {
		return fmt.Errorf("%T can only be used to decode non-nil pointers to slice or array values, got %T", sc, i)
	}
	dst := val.Elem()

	switch dst.Kind() {
	case reflect.Slice, reflect.Array:
		if !dst.CanSet() {
			return fmt.Errorf("%T can only decode settable slice and array values", sc)
		}
	default:
		return fmt.Errorf("%T can only decode settable slice and array values, got %T", sc, i)
	}

	switch vr.Type() {
	case TypeArray:
	case TypeNull:
		if dst.Kind() != reflect.Slice {
			return fmt.Errorf("cannot decode %v into an array", vr.Type())
		}
		// Consume the null before mutating the destination so that a failed
		// read leaves the caller's value untouched.
		if err := vr.ReadNull(); err != nil {
			return err
		}
		dst.Set(reflect.Zero(dst.Type()))
		return nil
	default:
		return fmt.Errorf("cannot decode %v into a slice", vr.Type())
	}

	eType := dst.Type().Elem()
	ar, err := vr.ReadArray()
	if err != nil {
		return err
	}

	var elems []reflect.Value
	if eType == tElement {
		elems, err = sc.decodeElement(dc, ar)
	} else {
		elems, err = sc.decodeDefault(dc, ar, eType)
	}
	if err != nil {
		return err
	}

	switch dst.Kind() {
	case reflect.Slice:
		slc := reflect.MakeSlice(dst.Type(), len(elems), len(elems))
		for idx, elem := range elems {
			slc.Index(idx).Set(elem)
		}
		dst.Set(slc)
	case reflect.Array:
		if len(elems) > dst.Len() {
			return fmt.Errorf("more elements returned in array than can fit inside %s", dst.Type())
		}
		for idx, elem := range elems {
			dst.Index(idx).Set(elem)
		}
		// A shorter BSON array must not leave the previous contents of the
		// tail in place; reset those slots to the element zero value.
		zero := reflect.Zero(eType)
		for idx := len(elems); idx < dst.Len(); idx++ {
			dst.Index(idx).Set(zero)
		}
	}
	return nil
}
// decodeElement reads every remaining value from ar and decodes each one into
// a fresh *Element via defaultElementCodec.decodeValue (with an empty key).
// The results are returned wrapped in reflect.Values, in array order.
// Reaching ErrEOA ends the read successfully; any other read or decode error
// is returned immediately together with a nil slice.
func (sc *SliceCodec) decodeElement(dc DecodeContext, ar ArrayReader) ([]reflect.Value, error) {
	out := make([]reflect.Value, 0)
	for {
		evr, err := ar.ReadValue()
		switch {
		case err == ErrEOA:
			return out, nil
		case err != nil:
			return nil, err
		}

		var e *Element
		if derr := defaultElementCodec.decodeValue(dc, evr, "", &e); derr != nil {
			return nil, derr
		}
		out = append(out, reflect.ValueOf(e))
	}
}
// decodeDefault decodes every remaining value in ar as a value of type eType.
// The codec for eType is looked up once via dc. Each element is then decoded
// into a newly allocated *eType, and the pointed-to value is collected. The
// decoded values are returned in array order once ErrEOA is reached. A lookup,
// read, or decode error is returned immediately together with a nil slice.
func (sc *SliceCodec) decodeDefault(dc DecodeContext, ar ArrayReader, eType reflect.Type) ([]reflect.Value, error) {
	codec, err := dc.Lookup(eType)
	if err != nil {
		return nil, err
	}

	out := make([]reflect.Value, 0)
	for {
		evr, rerr := ar.ReadValue()
		if rerr == ErrEOA {
			return out, nil
		}
		if rerr != nil {
			return nil, rerr
		}

		target := reflect.New(eType)
		if derr := codec.DecodeValue(dc, evr, target.Interface()); derr != nil {
			return nil, derr
		}
		out = append(out, target.Elem())
	}
}