diff --git a/bson/array_codec.go b/bson/array_codec.go index 4642fb6ea2..d235b69805 100644 --- a/bson/array_codec.go +++ b/bson/array_codec.go @@ -15,16 +15,6 @@ import ( // arrayCodec is the Codec used for bsoncore.Array values. type arrayCodec struct{} -// EncodeValue is the ValueEncoder for bsoncore.Array values. -func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tCoreArray { - return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val} - } - - arr := val.Interface().(bsoncore.Array) - return copyArrayFromBytes(vw, arr) -} - // DecodeValue is the ValueDecoder for bsoncore.Array values. func (ac *arrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCoreArray { diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 80e13e7d81..276b02f80c 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -139,6 +139,20 @@ func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val ref return fn(ec, vw, val) } +// reflectFreeValueEncoder is a reflect-free version of ValueEncoder. +type reflectFreeValueEncoder interface { + EncodeValue(ec EncodeContext, vw ValueWriter, val any) error +} + +// reflectFreeValueEncoderFunc is an adapter function that allows a function +// with the correct signature to be used as a reflectFreeValueEncoder. +type reflectFreeValueEncoderFunc func(ec EncodeContext, vw ValueWriter, val any) error + +// EncodeValue implements the reflectFreeValueEncoder interface. +func (fn reflectFreeValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val any) error { + return fn(ec, vw, val) +} + // ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type. // Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc, // ValueDecoderFunc is provided to allow the use of a function with the correct signature as a diff --git a/bson/byte_slice_codec.go b/bson/byte_slice_codec.go index bd44cf9a89..d6d27fcc86 100644 --- a/bson/byte_slice_codec.go +++ b/bson/byte_slice_codec.go @@ -12,28 +12,13 @@ import ( ) // byteSliceCodec is the Codec used for []byte values. -type byteSliceCodec struct { - // encodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values - // instead of BSON null. - encodeNilAsEmpty bool -} +type byteSliceCodec struct{} // Assert that byteSliceCodec satisfies the typeDecoder interface, which allows it to be // used by collection type decoders (e.g. map, slice, etc) to set individual values in a // collection. var _ typeDecoder = &byteSliceCodec{} -// EncodeValue is the ValueEncoder for []byte. -func (bsc *byteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tByteSlice { - return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} - } - if val.IsNil() && !bsc.encodeNilAsEmpty && !ec.nilByteSliceAsEmpty { - return vw.WriteNull() - } - return vw.WriteBinary(val.Interface().([]byte)) -} - func (bsc *byteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tByteSlice { return emptyValue, ValueDecoderError{ diff --git a/bson/codec_cache.go b/bson/codec_cache.go index b4042822e6..7bf8f7c419 100644 --- a/bson/codec_cache.go +++ b/bson/codec_cache.go @@ -58,6 +58,39 @@ func (c *typeEncoderCache) Clone() *typeEncoderCache { return cc } +type reflectFreeTypeEncoderCache struct { + cache sync.Map // map[reflect.Type]typeReflectFreeEncoderCache +} + +func (c *reflectFreeTypeEncoderCache) Store(rt reflect.Type, enc reflectFreeValueEncoder) { + c.cache.Store(rt, enc) +} + +func (c *reflectFreeTypeEncoderCache) Load(rt reflect.Type) (reflectFreeValueEncoder, bool) { + if v, _ := c.cache.Load(rt); v != nil { + return v.(reflectFreeValueEncoder), true + } + return nil, false +} + +func (c *reflectFreeTypeEncoderCache) LoadOrStore(rt reflect.Type, enc reflectFreeValueEncoder) reflectFreeValueEncoder { + if v, loaded := c.cache.LoadOrStore(rt, enc); loaded { + enc = v.(reflectFreeValueEncoder) + } + return enc +} + +func (c *reflectFreeTypeEncoderCache) Clone() *reflectFreeTypeEncoderCache { + cc := new(reflectFreeTypeEncoderCache) + c.cache.Range(func(k, v interface{}) bool { + if k != nil && v != nil { + cc.cache.Store(k, v) + } + return true + }) + return cc +} + type typeDecoderCache struct { cache sync.Map // map[reflect.Type]ValueDecoder } diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 4dad538a26..7ea2e55efd 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -3414,20 +3414,22 @@ func TestDefaultValueDecoders(t *testing.T) { // the top-level to decode to registered type when unmarshalling to interface{} topLevelReg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultEncoders(topLevelReg) registerDefaultDecoders(topLevelReg) topLevelReg.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) embeddedReg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultEncoders(embeddedReg) registerDefaultDecoders(embeddedReg) @@ -3470,10 +3472,11 @@ func TestDefaultValueDecoders(t *testing.T) { // type information is not available. reg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) @@ -3564,10 +3567,11 @@ func TestDefaultValueDecoders(t *testing.T) { // Use a registry that has all default decoders with the custom interface{} decoder that always errors. nestedRegistry := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultDecoders(nestedRegistry) nestedRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) @@ -3721,10 +3725,11 @@ func TestDefaultValueDecoders(t *testing.T) { ) reg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultDecoders(reg) reg.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) @@ -3795,10 +3800,11 @@ func buildDocument(elems []byte) []byte { func buildDefaultRegistry() *Registry { reg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index bd5a20f2f9..881acfe73d 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -13,6 +13,7 @@ import ( "net/url" "reflect" "sync" + "time" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) @@ -59,28 +60,34 @@ func registerDefaultEncoders(reg *Registry) { mapEncoder := &mapCodec{} uintCodec := &uintCodec{} - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{}) - reg.RegisterTypeEncoder(tTime, &timeCodec{}) - reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{}) - reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{}) - reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue)) - reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue)) - reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue)) - reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue)) - reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue)) - reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue)) - reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue)) - reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue)) - reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue)) - reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue)) - reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue)) - reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue)) - reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue)) - reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue)) - reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue)) - reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue)) - reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue)) - reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue)) + // Register the reflect-free default type encoders. + reg.registerReflectFreeTypeEncoder(tByteSlice, byteSliceEncodeValueRF(false)) + reg.registerReflectFreeTypeEncoder(tTime, reflectFreeValueEncoderFunc(timeEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tCoreArray, reflectFreeValueEncoderFunc(coreArrayEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tNull, reflectFreeValueEncoderFunc(nullEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tOID, reflectFreeValueEncoderFunc(objectIDEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tDecimal, reflectFreeValueEncoderFunc(decimal128EncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tJSONNumber, reflectFreeValueEncoderFunc(jsonNumberEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tURL, reflectFreeValueEncoderFunc(urlEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tJavaScript, reflectFreeValueEncoderFunc(javaScriptEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tSymbol, reflectFreeValueEncoderFunc(symbolEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tBinary, reflectFreeValueEncoderFunc(binaryEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tVector, reflectFreeValueEncoderFunc(vectorEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tUndefined, reflectFreeValueEncoderFunc(undefinedEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tDateTime, reflectFreeValueEncoderFunc(dateTimeEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tRegex, reflectFreeValueEncoderFunc(regexEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tDBPointer, reflectFreeValueEncoderFunc(dbPointerEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tTimestamp, reflectFreeValueEncoderFunc(timestampEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tMinKey, reflectFreeValueEncoderFunc(minKeyEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tMaxKey, reflectFreeValueEncoderFunc(maxKeyEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tCoreDocument, reflectFreeValueEncoderFunc(coreDocumentEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tCodeWithScope, reflectFreeValueEncoderFunc(codeWithScopeEncodeValueRF)) + + // Register the reflect-based default encoders. + reg.RegisterTypeEncoder(tEmpty, ValueEncoderFunc(emptyInterfaceValue)) + + // Register the kind-based default encoders. These must continue using + // reflection since they account for custom types that cannot be anticipated. reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue)) reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue)) reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue)) @@ -100,6 +107,8 @@ func registerDefaultEncoders(reg *Registry) { reg.RegisterKindEncoder(reflect.String, &stringCodec{}) reg.RegisterKindEncoder(reflect.Struct, newStructCodec(mapEncoder)) reg.RegisterKindEncoder(reflect.Ptr, &pointerCodec{}) + + // Register the interface-based default encoders. reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue)) reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue)) } @@ -142,7 +151,21 @@ func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { } } -// floatEncodeValue is the ValueEncoderFunc for float types. +func floatEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + if f32, ok := val.(float32); ok { + return vw.WriteDouble(float64(f32)) + } + + if f64, ok := val.(float64); ok { + return vw.WriteDouble(f64) + } + + return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: reflect.ValueOf(val)} +} + +// floatEncodeValue is the ValueEncoderFunc for float types. this function is +// used to decode "types" and "kinds" and therefore cannot be a wrapper for +// reflection-free decoding in the default "type" case. func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Float32, reflect.Float64: @@ -153,27 +176,29 @@ func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error } // objectIDEncodeValue is the ValueEncoderFunc for ObjectID. -func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tOID { - return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val} +func objectIDEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + objID, ok := val.(ObjectID) + if !ok { + return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: reflect.ValueOf(val)} } - return vw.WriteObjectID(val.Interface().(ObjectID)) + + return vw.WriteObjectID(objID) } -// decimal128EncodeValue is the ValueEncoderFunc for Decimal128. -func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tDecimal { - return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val} +func decimal128EncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + d128, ok := val.(Decimal128) + if !ok { + return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: reflect.ValueOf(val)} } - return vw.WriteDecimal128(val.Interface().(Decimal128)) + + return vw.WriteDecimal128(d128) } -// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. -func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tJSONNumber { - return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} +func jsonNumberEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { + jsnum, ok := val.(json.Number) + if !ok { + return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: reflect.ValueOf(val)} } - jsnum := val.Interface().(json.Number) // Attempt int first, then float64 if i64, err := jsnum.Int64(); err == nil { @@ -185,15 +210,15 @@ func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) return err } - return floatEncodeValue(ec, vw, reflect.ValueOf(f64)) + return floatEncodeValueRF(ec, vw, f64) } -// urlEncodeValue is the ValueEncoderFunc for url.URL. -func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tURL { - return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val} +func urlEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + u, ok := val.(url.URL) + if !ok { + return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: reflect.ValueOf(val)} } - u := val.Interface().(url.URL) + return vw.WriteString(u.String()) } @@ -337,145 +362,126 @@ func marshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) er return copyValueFromBytes(vw, TypeEmbeddedDocument, data) } -// javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. -func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tJavaScript { - return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val} +func javaScriptEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + jsString, ok := val.(JavaScript) + if !ok { + return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: reflect.ValueOf(val)} } - return vw.WriteJavascript(val.String()) + return vw.WriteJavascript(string(jsString)) } -// symbolEncodeValue is the ValueEncoderFunc for the Symbol type. -func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tSymbol { - return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val} +func symbolEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + symbol, ok := val.(Symbol) + if !ok { + return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: reflect.ValueOf(val)} } - return vw.WriteSymbol(val.String()) + return vw.WriteSymbol(string(symbol)) } -// binaryEncodeValue is the ValueEncoderFunc for Binary. -func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tBinary { - return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val} +func binaryEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + b, ok := val.(Binary) + if !ok { + return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: reflect.ValueOf(val)} } - b := val.Interface().(Binary) return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } -// vectorEncodeValue is the ValueEncoderFunc for Vector. -func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - t := val.Type() - if !val.IsValid() || t != tVector { - return ValueEncoderError{Name: "VectorEncodeValue", - Types: []reflect.Type{tVector}, - Received: val, - } +func vectorEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + v, ok := val.(Vector) + if !ok { + return ValueEncoderError{Name: "VectorEncodeValue", Types: []reflect.Type{tVector}, Received: reflect.ValueOf(val)} } - v := val.Interface().(Vector) + b := v.Binary() return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } -// undefinedEncodeValue is the ValueEncoderFunc for Undefined. -func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tUndefined { - return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val} +func undefinedEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + if _, ok := val.(Undefined); !ok { + return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: reflect.ValueOf(val)} } return vw.WriteUndefined() } -// dateTimeEncodeValue is the ValueEncoderFunc for DateTime. -func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tDateTime { - return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val} +func dateTimeEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + dateTime, ok := val.(DateTime) + if !ok { + return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: reflect.ValueOf(val)} } - return vw.WriteDateTime(val.Int()) + return vw.WriteDateTime(int64(dateTime)) } -// nullEncodeValue is the ValueEncoderFunc for Null. -func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tNull { - return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val} +func nullEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + if _, ok := val.(Null); !ok { + return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: reflect.ValueOf(val)} } return vw.WriteNull() } -// regexEncodeValue is the ValueEncoderFunc for Regex. -func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tRegex { - return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val} +func regexEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + regex, ok := val.(Regex) + if !ok { + return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: reflect.ValueOf(val)} } - regex := val.Interface().(Regex) - return vw.WriteRegex(regex.Pattern, regex.Options) } -// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. -func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tDBPointer { - return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val} +func dbPointerEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + dbp, ok := val.(DBPointer) + if !ok { + return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: reflect.ValueOf(val)} } - dbp := val.Interface().(DBPointer) - return vw.WriteDBPointer(dbp.DB, dbp.Pointer) } -// timestampEncodeValue is the ValueEncoderFunc for Timestamp. -func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tTimestamp { - return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val} +func timestampEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + ts, ok := val.(Timestamp) + if !ok { + return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: reflect.ValueOf(val)} } - ts := val.Interface().(Timestamp) - return vw.WriteTimestamp(ts.T, ts.I) } -// minKeyEncodeValue is the ValueEncoderFunc for MinKey. -func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tMinKey { - return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val} +func minKeyEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + if _, ok := val.(MinKey); !ok { + return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: reflect.ValueOf(val)} } return vw.WriteMinKey() } -// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. -func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tMaxKey { - return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val} +func maxKeyEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + if _, ok := val.(MaxKey); !ok { + return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: reflect.ValueOf(val)} } return vw.WriteMaxKey() } -// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. -func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tCoreDocument { - return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} +func coreDocumentEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + cdoc, ok := val.(bsoncore.Document) + if !ok { + return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: reflect.ValueOf(val)} } - cdoc := val.Interface().(bsoncore.Document) - return copyDocumentFromBytes(vw, cdoc) } -// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. -func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tCodeWithScope { - return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} +func codeWithScopeEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { + cws, ok := val.(CodeWithScope) + if !ok { + return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: reflect.ValueOf(val)} } - cws := val.Interface().(CodeWithScope) - dw, err := vw.WriteCodeWithScope(string(cws.Code)) if err != nil { return err @@ -489,7 +495,6 @@ func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Valu scopeVW.reset(scopeVW.buf[:0]) scopeVW.w = sw defer bvwPool.Put(scopeVW) - encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope)) if err != nil { return err @@ -515,3 +520,58 @@ func isImplementationNil(val reflect.Value, inter reflect.Type) bool { } return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil() } + +func byteSliceEncodeValueRF(encodeNilAsEmpty bool) reflectFreeValueEncoderFunc { + return reflectFreeValueEncoderFunc(func(ec EncodeContext, vw ValueWriter, val any) error { + byteSlice, ok := val.([]byte) + if !ok { + return ValueEncoderError{ + Name: "ByteSliceEncodeValue", + Types: []reflect.Type{tByteSlice}, + Received: reflect.ValueOf(val), + } + } + + if byteSlice == nil && !encodeNilAsEmpty && !ec.nilByteSliceAsEmpty { + return vw.WriteNull() + } + + return vw.WriteBinary(byteSlice) + }) +} + +func timeEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + tt, ok := val.(time.Time) + if !ok { + return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: reflect.ValueOf(val)} + } + + dt := NewDateTimeFromTime(tt) + return vw.WriteDateTime(int64(dt)) +} + +func coreArrayEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + arr, ok := val.(bsoncore.Array) + if !ok { + return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: reflect.ValueOf(val)} + } + + return copyArrayFromBytes(vw, arr) +} + +func emptyInterfaceValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tEmpty { + return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} + } + + if val.IsNil() { + return vw.WriteNull() + } + + encoder, err := ec.LookupEncoder(val.Elem().Type()) + if err != nil { + return err + } + + return encoder.EncodeValue(ec, vw, val.Elem()) +} diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index e15019785d..fa01fdbacb 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -73,11 +73,13 @@ func TestDefaultValueEncoders(t *testing.T) { testCases := []struct { name string ve ValueEncoder + rfve reflectFreeValueEncoder subtests []subtest }{ { "BooleanEncodeValue", ValueEncoderFunc(booleanEncodeValue), + nil, []subtest{ { "wrong type", @@ -94,6 +96,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "IntEncodeValue", ValueEncoderFunc(intEncodeValue), + nil, []subtest{ { "wrong type", @@ -134,6 +137,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "UintEncodeValue", &uintCodec{}, + nil, []subtest{ { "wrong type", @@ -175,6 +179,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "FloatEncodeValue", ValueEncoderFunc(floatEncodeValue), + nil, []subtest{ { "wrong type", @@ -194,9 +199,34 @@ func TestDefaultValueEncoders(t *testing.T) { {"float64/reflection path", myfloat64(3.14159), nil, nil, writeDouble, nil}, }, }, + { + "reflection free FloatEncodeValue", + nil, + reflectFreeValueEncoderFunc(floatEncodeValueRF), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + nothing, + ValueEncoderError{ + Name: "FloatEncodeValue", + Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Received: reflect.ValueOf(wrong), + }, + }, + // the reflection free encoder function should only be used for + // encoding "types", not "kinds". So the reflection path tests are not + // valid. + {"float32/fast path", float32(3.14159), nil, nil, writeDouble, nil}, + {"float64/fast path", float64(3.14159), nil, nil, writeDouble, nil}, + }, + }, { "TimeEncodeValue", - &timeCodec{}, + nil, + reflectFreeValueEncoderFunc(timeEncodeValueRF), []subtest{ { "wrong type", @@ -212,6 +242,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "MapEncodeValue", &mapCodec{}, + nil, []subtest{ { "wrong kind", @@ -292,6 +323,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "ArrayEncodeValue", ValueEncoderFunc(arrayEncodeValue), + nil, []subtest{ { "wrong kind", @@ -370,6 +402,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "SliceEncodeValue", &sliceCodec{}, + nil, []subtest{ { "wrong kind", @@ -455,7 +488,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ObjectIDEncodeValue", - ValueEncoderFunc(objectIDEncodeValue), + nil, + reflectFreeValueEncoderFunc(objectIDEncodeValueRF), []subtest{ { "wrong type", @@ -474,7 +508,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "Decimal128EncodeValue", - ValueEncoderFunc(decimal128EncodeValue), + nil, + reflectFreeValueEncoderFunc(decimal128EncodeValueRF), []subtest{ { "wrong type", @@ -489,7 +524,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "JSONNumberEncodeValue", - ValueEncoderFunc(jsonNumberEncodeValue), + nil, + reflectFreeValueEncoderFunc(jsonNumberEncodeValueRF), []subtest{ { "wrong type", @@ -518,7 +554,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "URLEncodeValue", - ValueEncoderFunc(urlEncodeValue), + nil, + reflectFreeValueEncoderFunc(urlEncodeValueRF), []subtest{ { "wrong type", @@ -533,7 +570,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ByteSliceEncodeValue", - &byteSliceCodec{}, + nil, + byteSliceEncodeValueRF(false), []subtest{ { "wrong type", @@ -549,7 +587,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "EmptyInterfaceEncodeValue", - &emptyInterfaceCodec{}, + ValueEncoderFunc(emptyInterfaceValue), + nil, []subtest{ { "wrong type", @@ -564,6 +603,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "ValueMarshalerEncodeValue", ValueEncoderFunc(valueMarshalerEncodeValue), + nil, []subtest{ { "wrong type", @@ -642,6 +682,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "MarshalerEncodeValue", ValueEncoderFunc(marshalerEncodeValue), + nil, []subtest{ { "wrong type", @@ -704,6 +745,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "PointerCodec.EncodeValue", &pointerCodec{}, + nil, []subtest{ { "nil", @@ -742,6 +784,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "pointer implementation addressable interface", &pointerCodec{}, + nil, []subtest{ { "ValueMarshaler", @@ -763,7 +806,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "JavaScriptEncodeValue", - ValueEncoderFunc(javaScriptEncodeValue), + nil, + reflectFreeValueEncoderFunc(javaScriptEncodeValueRF), []subtest{ { "wrong type", @@ -778,7 +822,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "SymbolEncodeValue", - ValueEncoderFunc(symbolEncodeValue), + nil, + reflectFreeValueEncoderFunc(symbolEncodeValueRF), []subtest{ { "wrong type", @@ -793,7 +838,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "BinaryEncodeValue", - ValueEncoderFunc(binaryEncodeValue), + nil, + reflectFreeValueEncoderFunc(binaryEncodeValueRF), []subtest{ { "wrong type", @@ -808,7 +854,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "UndefinedEncodeValue", - ValueEncoderFunc(undefinedEncodeValue), + nil, + reflectFreeValueEncoderFunc(undefinedEncodeValueRF), []subtest{ { "wrong type", @@ -823,7 +870,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "DateTimeEncodeValue", - ValueEncoderFunc(dateTimeEncodeValue), + nil, + reflectFreeValueEncoderFunc(dateTimeEncodeValueRF), []subtest{ { "wrong type", @@ -838,7 +886,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "NullEncodeValue", - ValueEncoderFunc(nullEncodeValue), + nil, + reflectFreeValueEncoderFunc(nullEncodeValueRF), []subtest{ { "wrong type", @@ -853,7 +902,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "RegexEncodeValue", - ValueEncoderFunc(regexEncodeValue), + nil, + reflectFreeValueEncoderFunc(regexEncodeValueRF), []subtest{ { "wrong type", @@ -868,7 +918,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "DBPointerEncodeValue", - ValueEncoderFunc(dbPointerEncodeValue), + nil, + reflectFreeValueEncoderFunc(dbPointerEncodeValueRF), []subtest{ { "wrong type", @@ -890,7 +941,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "TimestampEncodeValue", - ValueEncoderFunc(timestampEncodeValue), + nil, + reflectFreeValueEncoderFunc(timestampEncodeValueRF), []subtest{ { "wrong type", @@ -905,7 +957,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MinKeyEncodeValue", - ValueEncoderFunc(minKeyEncodeValue), + nil, + reflectFreeValueEncoderFunc(minKeyEncodeValueRF), []subtest{ { "wrong type", @@ -920,7 +973,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MaxKeyEncodeValue", - ValueEncoderFunc(maxKeyEncodeValue), + nil, + reflectFreeValueEncoderFunc(maxKeyEncodeValueRF), []subtest{ { "wrong type", @@ -935,7 +989,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CoreDocumentEncodeValue", - ValueEncoderFunc(coreDocumentEncodeValue), + nil, + reflectFreeValueEncoderFunc(coreDocumentEncodeValueRF), []subtest{ { "wrong type", @@ -994,6 +1049,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "StructEncodeValue", newStructCodec(&mapCodec{}), + nil, []subtest{ { "interface value", @@ -1015,7 +1071,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CodeWithScopeEncodeValue", - ValueEncoderFunc(codeWithScopeEncodeValue), + nil, + reflectFreeValueEncoderFunc(codeWithScopeEncodeValueRF), []subtest{ { "wrong type", @@ -1050,7 +1107,8 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CoreArrayEncodeValue", - &arrayCodec{}, + nil, + reflectFreeValueEncoderFunc(coreArrayEncodeValueRF), []subtest{ { "wrong type", @@ -1110,13 +1168,27 @@ func TestDefaultValueEncoders(t *testing.T) { llvrw = subtest.llvrw } llvrw.T = t - err := tc.ve.EncodeValue(ec, llvrw, reflect.ValueOf(subtest.val)) - if !assert.CompareErrors(err, subtest.err) { - t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) + + if tc.ve != nil { + err := tc.ve.EncodeValue(ec, llvrw, reflect.ValueOf(subtest.val)) + if !assert.CompareErrors(err, subtest.err) { + t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) + } + invoked := llvrw.invoked + if !cmp.Equal(invoked, subtest.invoke) { + t.Errorf("Incorrect method invoked. got %v; want %v", invoked, subtest.invoke) + } } - invoked := llvrw.invoked - if !cmp.Equal(invoked, subtest.invoke) { - t.Errorf("Incorrect method invoked. got %v; want %v", invoked, subtest.invoke) + + if tc.rfve != nil { + err := tc.rfve.EncodeValue(ec, llvrw, subtest.val) + if !assert.CompareErrors(err, subtest.err) { + t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) + } + invoked := llvrw.invoked + if !cmp.Equal(invoked, subtest.invoke) { + t.Errorf("Incorrect method invoked. got %v; want %v", invoked, subtest.invoke) + } } }) } @@ -1719,7 +1791,7 @@ func TestDefaultValueEncoders(t *testing.T) { t.Run("EmptyInterfaceEncodeValue/nil", func(t *testing.T) { val := reflect.New(tEmpty).Elem() llvrw := new(valueReaderWriter) - err := (&emptyInterfaceCodec{}).EncodeValue(EncodeContext{Registry: newTestRegistry()}, llvrw, val) + err := emptyInterfaceValue(EncodeContext{Registry: newTestRegistry()}, llvrw, val) noerr(t, err) if llvrw.invoked != writeNull { t.Errorf("Incorrect method called. got %v; want %v", llvrw.invoked, writeNull) @@ -1730,10 +1802,10 @@ func TestDefaultValueEncoders(t *testing.T) { val := reflect.New(tEmpty).Elem() val.Set(reflect.ValueOf(int64(1234567890))) llvrw := new(valueReaderWriter) - got := (&emptyInterfaceCodec{}).EncodeValue(EncodeContext{Registry: newTestRegistry()}, llvrw, val) + err := emptyInterfaceValue(EncodeContext{Registry: newTestRegistry()}, llvrw, val) want := errNoEncoder{Type: tInt64} - if !assert.CompareErrors(got, want) { - t.Errorf("Did not receive expected error. got %v; want %v", got, want) + if !assert.CompareErrors(err, want) { + t.Errorf("Did not receive expected error. got %v; want %v", err, want) } }) } diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index 80d44d8c66..b669dfd8d9 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -17,28 +17,6 @@ type emptyInterfaceCodec struct { decodeBinaryAsSlice bool } -// Assert that emptyInterfaceCodec satisfies the typeDecoder interface, which allows it -// to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a -// collection. -var _ typeDecoder = &emptyInterfaceCodec{} - -// EncodeValue is the ValueEncoderFunc for interface{}. -func (eic *emptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tEmpty { - return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} - } - - if val.IsNil() { - return vw.WriteNull() - } - encoder, err := ec.LookupEncoder(val.Elem().Type()) - if err != nil { - return err - } - - return encoder.EncodeValue(ec, vw, val.Elem()) -} - func (eic *emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) { isDocument := valueType == Type(0) || valueType == TypeEmbeddedDocument if isDocument { diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go index f42935e5d8..6e11353168 100644 --- a/bson/mgoregistry.go +++ b/bson/mgoregistry.go @@ -38,11 +38,12 @@ func NewMgoRegistry() *Registry { uintCodec := &uintCodec{encodeToMinSize: true} reg := NewRegistry() + reg.registerReflectFreeTypeEncoder(tByteSlice, byteSliceEncodeValueRF(true)) + reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{decodeBinaryAsSlice: true}) reg.RegisterKindDecoder(reflect.String, ValueDecoderFunc(mgoStringDecodeValue)) reg.RegisterKindDecoder(reflect.Struct, structCodec) reg.RegisterKindDecoder(reflect.Map, mapCodec) - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: true}) reg.RegisterKindEncoder(reflect.Struct, structCodec) reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{encodeNilAsEmpty: true}) reg.RegisterKindEncoder(reflect.Map, mapCodec) @@ -69,8 +70,9 @@ func NewRespectNilValuesMgoRegistry() *Registry { } reg := NewMgoRegistry() + reg.registerReflectFreeTypeEncoder(tByteSlice, byteSliceEncodeValueRF(false)) + reg.RegisterKindDecoder(reflect.Map, mapCodec) - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: false}) reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{}) reg.RegisterKindEncoder(reflect.Map, mapCodec) return reg diff --git a/bson/registry.go b/bson/registry.go index d8f65ddc0d..dbe4595e10 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -88,6 +88,8 @@ type Registry struct { kindEncoders *kindEncoderCache kindDecoders *kindDecoderCache typeMap sync.Map // map[Type]reflect.Type + + reflectFreeTypeEncoders *reflectFreeTypeEncoderCache } // NewRegistry creates a new empty Registry. @@ -97,6 +99,8 @@ func NewRegistry() *Registry { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), + + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) @@ -118,6 +122,10 @@ func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) r.typeEncoders.Store(valueType, enc) } +func (r *Registry) registerReflectFreeTypeEncoder(valueType reflect.Type, enc reflectFreeValueEncoder) { + r.reflectFreeTypeEncoders.Store(valueType, enc) +} + // RegisterTypeDecoder registers the provided ValueDecoder for the provided type. // // The type will be used as provided, so a decoder can be registered for a type and a different @@ -244,16 +252,26 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { if valueType == nil { return nil, errNoEncoder{Type: valueType} } - enc, found := r.lookupTypeEncoder(valueType) - if found { + + // First attempt to get a user-defined type encoder. + if enc, found := r.lookupTypeEncoder(valueType); found { if enc == nil { return nil, errNoEncoder{Type: valueType} } + return enc, nil } - enc, found = r.lookupInterfaceEncoder(valueType, true) - if found { + // Next try to get a reflection-free encoder. + if rfeEnc, found := r.reflectFreeTypeEncoders.Load(valueType); found && rfeEnc != nil { + wrapper := func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return rfeEnc.EncodeValue(ec, vw, val.Interface()) + } + + return ValueEncoderFunc(wrapper), nil + } + + if enc, found := r.lookupInterfaceEncoder(valueType, true); found { return r.typeEncoders.LoadOrStore(valueType, enc), nil } diff --git a/bson/registry_test.go b/bson/registry_test.go index ea7b2b2ef7..fd603cd66e 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -22,6 +22,8 @@ func newTestRegistry() *Registry { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), + + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } } diff --git a/bson/time_codec.go b/bson/time_codec.go index 1c00374c19..32be418a6b 100644 --- a/bson/time_codec.go +++ b/bson/time_codec.go @@ -97,13 +97,3 @@ func (tc *timeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.V val.Set(elem) return nil } - -// EncodeValue is the ValueEncoderFunc for time.TIme. -func (tc *timeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tTime { - return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} - } - tt := val.Interface().(time.Time) - dt := NewDateTimeFromTime(tt) - return vw.WriteDateTime(int64(dt)) -}