From 7e17312ebf20c5b525b9dd07f0386de78249b546 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Wed, 9 Apr 2025 15:55:05 -0600 Subject: [PATCH 01/13] GODRIVER-3455 Update RF type encoders --- bson/array_codec.go | 8 +- bson/bsoncodec.go | 19 ++ bson/byte_slice_codec.go | 11 - bson/codec_cache.go | 67 ++++++ bson/default_value_decoders_test.go | 18 ++ bson/default_value_encoders.go | 352 ++++++++++++++++++++++++++-- bson/default_value_encoders_test.go | 187 +++++++-------- bson/mgoregistry.go | 8 +- bson/registry.go | 171 ++++++++++++++ bson/registry_test.go | 3 + bson/time_codec.go | 11 +- 11 files changed, 719 insertions(+), 136 deletions(-) diff --git a/bson/array_codec.go b/bson/array_codec.go index 4642fb6ea2..aa65803959 100644 --- a/bson/array_codec.go +++ b/bson/array_codec.go @@ -16,12 +16,12 @@ import ( type arrayCodec struct{} // EncodeValue is the ValueEncoder for bsoncore.Array values. -func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tCoreArray { - return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val} +func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val any) error { + arr, ok := val.(bsoncore.Array) + if !ok { + return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: reflect.ValueOf(val)} } - arr := val.Interface().(bsoncore.Array) return copyArrayFromBytes(vw, arr) } diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 80e13e7d81..1a8ca03104 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -139,6 +139,25 @@ func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val ref return fn(ec, vw, val) } +// defaultValueEncoderFunc is an adapter function that allows a function with +// the correct signature to be used as a ValueEncoder. +type defaultValueEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error + +// EncodeValue implements the ValueEncoder interface. +func (fn defaultValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return fn(ec, vw, val) +} + +type reflectFreeValueEncoder interface { + EncodeValue(ec EncodeContext, vw ValueWriter, val any) error +} + +type reflectFreeValueEncoderFunc func(ec EncodeContext, vw ValueWriter, val any) error + +func (fn reflectFreeValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val any) error { + return fn(ec, vw, val) +} + // ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type. // Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc, // ValueDecoderFunc is provided to allow the use of a function with the correct signature as a diff --git a/bson/byte_slice_codec.go b/bson/byte_slice_codec.go index bd44cf9a89..ce0ad976ec 100644 --- a/bson/byte_slice_codec.go +++ b/bson/byte_slice_codec.go @@ -23,17 +23,6 @@ type byteSliceCodec struct { // collection. var _ typeDecoder = &byteSliceCodec{} -// EncodeValue is the ValueEncoder for []byte. -func (bsc *byteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tByteSlice { - return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} - } - if val.IsNil() && !bsc.encodeNilAsEmpty && !ec.nilByteSliceAsEmpty { - return vw.WriteNull() - } - return vw.WriteBinary(val.Interface().([]byte)) -} - func (bsc *byteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tByteSlice { return emptyValue, ValueDecoderError{ diff --git a/bson/codec_cache.go b/bson/codec_cache.go index b4042822e6..472d11ce23 100644 --- a/bson/codec_cache.go +++ b/bson/codec_cache.go @@ -24,6 +24,7 @@ func init() { // statically assert array size var _ = (kindEncoderCache{}).entries[reflect.UnsafePointer] var _ = (kindDecoderCache{}).entries[reflect.UnsafePointer] +var _ = (kindEncoderReflectFreeCache{}).entries[reflect.UnsafePointer] type typeEncoderCache struct { cache sync.Map // map[reflect.Type]ValueEncoder @@ -58,6 +59,39 @@ func (c *typeEncoderCache) Clone() *typeEncoderCache { return cc } +type typeReflectFreeEncoderCache struct { + cache sync.Map // map[reflect.Type]typeReflectFreeEncoderCache +} + +func (c *typeReflectFreeEncoderCache) Store(rt reflect.Type, enc reflectFreeValueEncoder) { + c.cache.Store(rt, enc) +} + +func (c *typeReflectFreeEncoderCache) Load(rt reflect.Type) (reflectFreeValueEncoder, bool) { + if v, _ := c.cache.Load(rt); v != nil { + return v.(reflectFreeValueEncoder), true + } + return nil, false +} + +func (c *typeReflectFreeEncoderCache) LoadOrStore(rt reflect.Type, enc reflectFreeValueEncoder) reflectFreeValueEncoder { + if v, loaded := c.cache.LoadOrStore(rt, enc); loaded { + enc = v.(reflectFreeValueEncoder) + } + return enc +} + +func (c *typeReflectFreeEncoderCache) Clone() *typeReflectFreeEncoderCache { + cc := new(typeReflectFreeEncoderCache) + c.cache.Range(func(k, v interface{}) bool { + if k != nil && v != nil { + cc.cache.Store(k, v) + } + return true + }) + return cc +} + type typeDecoderCache struct { cache sync.Map // map[reflect.Type]ValueDecoder } @@ -95,6 +129,39 @@ func (c *typeDecoderCache) Clone() *typeDecoderCache { // so we wrap the ValueEncoder with a kindEncoderCacheEntry to ensure the type // is always the same (since different concrete types may implement the // ValueEncoder interface). +type kindEncoderReflectFreeCacheEntry struct { + enc reflectFreeValueEncoder +} + +type kindEncoderReflectFreeCache struct { + entries [reflect.UnsafePointer + 1]atomic.Value // *kindEncoderCacheEntry +} + +func (c *kindEncoderReflectFreeCache) Store(rt reflect.Kind, enc reflectFreeValueEncoder) { + if enc != nil && rt < reflect.Kind(len(c.entries)) { + c.entries[rt].Store(&kindEncoderReflectFreeCacheEntry{enc: enc}) + } +} + +func (c *kindEncoderReflectFreeCache) Load(rt reflect.Kind) (reflectFreeValueEncoder, bool) { + if rt < reflect.Kind(len(c.entries)) { + if ent, ok := c.entries[rt].Load().(*kindEncoderReflectFreeCacheEntry); ok { + return ent.enc, ent.enc != nil + } + } + return nil, false +} + +func (c *kindEncoderReflectFreeCache) Clone() *kindEncoderReflectFreeCache { + cc := new(kindEncoderReflectFreeCache) + for i, v := range c.entries { + if val := v.Load(); val != nil { + cc.entries[i].Store(val) + } + } + return cc +} + type kindEncoderCacheEntry struct { enc ValueEncoder } diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 4dad538a26..43379d1d14 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -3418,6 +3418,9 @@ func TestDefaultValueDecoders(t *testing.T) { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), + + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultEncoders(topLevelReg) registerDefaultDecoders(topLevelReg) @@ -3428,6 +3431,9 @@ func TestDefaultValueDecoders(t *testing.T) { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), + + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultEncoders(embeddedReg) registerDefaultDecoders(embeddedReg) @@ -3474,6 +3480,9 @@ func TestDefaultValueDecoders(t *testing.T) { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), + + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) @@ -3568,6 +3577,9 @@ func TestDefaultValueDecoders(t *testing.T) { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), + + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultDecoders(nestedRegistry) nestedRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) @@ -3725,6 +3737,9 @@ func TestDefaultValueDecoders(t *testing.T) { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), + + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultDecoders(reg) reg.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) @@ -3799,6 +3814,9 @@ func buildDefaultRegistry() *Registry { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), + + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index bd5a20f2f9..c20b28dc89 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -13,6 +13,7 @@ import ( "net/url" "reflect" "sync" + "time" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) @@ -59,28 +60,58 @@ func registerDefaultEncoders(reg *Registry) { mapEncoder := &mapCodec{} uintCodec := &uintCodec{} - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{}) - reg.RegisterTypeEncoder(tTime, &timeCodec{}) - reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{}) - reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{}) - reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue)) - reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue)) - reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue)) - reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue)) - reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue)) - reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue)) - reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue)) - reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue)) - reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue)) - reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue)) - reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue)) - reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue)) - reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue)) - reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue)) - reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue)) - reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue)) - reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue)) - reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue)) + // Register the reflect-free default type encoders. + reg.registerReflectFreeTypeEncoder(tByteSlice, byteSliceEncodeValueRF(false)) + reg.registerReflectFreeTypeEncoder(tTime, reflectFreeValueEncoderFunc(timeEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tCoreArray, reflectFreeValueEncoderFunc(coreArrayEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tNull, reflectFreeValueEncoderFunc(nullEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tOID, reflectFreeValueEncoderFunc(objectIDEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tDecimal, reflectFreeValueEncoderFunc(decimal128EncodeValueX)) + reg.registerReflectFreeTypeEncoder(tJSONNumber, reflectFreeValueEncoderFunc(jsonNumberEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tURL, reflectFreeValueEncoderFunc(urlEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tJavaScript, reflectFreeValueEncoderFunc(javaScriptEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tSymbol, reflectFreeValueEncoderFunc(symbolEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tBinary, reflectFreeValueEncoderFunc(binaryEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tVector, reflectFreeValueEncoderFunc(vectorEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tUndefined, reflectFreeValueEncoderFunc(undefinedEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tDateTime, reflectFreeValueEncoderFunc(dateTimeEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tRegex, reflectFreeValueEncoderFunc(regexEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tDBPointer, reflectFreeValueEncoderFunc(dbPointerEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tTimestamp, reflectFreeValueEncoderFunc(timestampEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tMinKey, reflectFreeValueEncoderFunc(minKeyEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tMaxKey, reflectFreeValueEncoderFunc(maxKeyEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tCoreDocument, reflectFreeValueEncoderFunc(coreDocumentEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tCodeWithScope, reflectFreeValueEncoderFunc(codeWithScopeEncodeValueX)) + + // Register the reflect-based default encoders. These are required since + // removing them would break Registry.LookupEncoder. However, these will + // never be used internally. + // + reg.RegisterTypeEncoder(tByteSlice, byteSliceEncodeValue(false)) + reg.RegisterTypeEncoder(tTime, defaultValueEncoderFunc(timeEncodeValue)) + reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{}) // TODO: extend this to reflection free + reg.RegisterTypeEncoder(tCoreArray, defaultValueEncoderFunc(coreArrayEncodeValue)) + reg.RegisterTypeEncoder(tOID, defaultValueEncoderFunc(objectIDEncodeValue)) + reg.RegisterTypeEncoder(tDecimal, defaultValueEncoderFunc(decimal128EncodeValue)) + reg.RegisterTypeEncoder(tJSONNumber, defaultValueEncoderFunc(jsonNumberEncodeValue)) + reg.RegisterTypeEncoder(tURL, defaultValueEncoderFunc(urlEncodeValue)) + reg.RegisterTypeEncoder(tJavaScript, defaultValueEncoderFunc(javaScriptEncodeValue)) + reg.RegisterTypeEncoder(tSymbol, defaultValueEncoderFunc(symbolEncodeValue)) + reg.RegisterTypeEncoder(tBinary, defaultValueEncoderFunc(binaryEncodeValue)) + reg.RegisterTypeEncoder(tVector, defaultValueEncoderFunc(vectorEncodeValue)) + reg.RegisterTypeEncoder(tUndefined, defaultValueEncoderFunc(undefinedEncodeValue)) + reg.RegisterTypeEncoder(tDateTime, defaultValueEncoderFunc(dateTimeEncodeValue)) + reg.RegisterTypeEncoder(tNull, defaultValueEncoderFunc(nullEncodeValue)) + reg.RegisterTypeEncoder(tRegex, defaultValueEncoderFunc(regexEncodeValue)) + reg.RegisterTypeEncoder(tDBPointer, defaultValueEncoderFunc(dbPointerEncodeValue)) + reg.RegisterTypeEncoder(tTimestamp, defaultValueEncoderFunc(timestampEncodeValue)) + reg.RegisterTypeEncoder(tMinKey, defaultValueEncoderFunc(minKeyEncodeValue)) + reg.RegisterTypeEncoder(tMaxKey, defaultValueEncoderFunc(maxKeyEncodeValue)) + reg.RegisterTypeEncoder(tCoreDocument, defaultValueEncoderFunc(coreDocumentEncodeValue)) + reg.RegisterTypeEncoder(tCodeWithScope, defaultValueEncoderFunc(codeWithScopeEncodeValue)) + + // Register the kind-based default encoders. These must continue using + // reflection since they account for custom types that cannot be anticipated. reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue)) reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue)) reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue)) @@ -100,6 +131,8 @@ func registerDefaultEncoders(reg *Registry) { reg.RegisterKindEncoder(reflect.String, &stringCodec{}) reg.RegisterKindEncoder(reflect.Struct, newStructCodec(mapEncoder)) reg.RegisterKindEncoder(reflect.Ptr, &pointerCodec{}) + + // Register the interface-based default encoders. reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue)) reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue)) } @@ -152,6 +185,15 @@ func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val} } +func floatEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + switch val := val.(type) { + case float32, float64: + return vw.WriteDouble(val.(float64)) + } + + return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: reflect.ValueOf(val)} +} + // objectIDEncodeValue is the ValueEncoderFunc for ObjectID. func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tOID { @@ -160,6 +202,20 @@ func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) err return vw.WriteObjectID(val.Interface().(ObjectID)) } +// objectIDEncodeValue is the ValueEncoderFunc for ObjectID. +func objectIDEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + objID, ok := val.(ObjectID) + if !ok { + return ValueEncoderError{ + Name: "ObjectIDEncodeValue", + Types: []reflect.Type{tOID}, + Received: reflect.ValueOf(val), + } + } + + return vw.WriteObjectID(objID) +} + // decimal128EncodeValue is the ValueEncoderFunc for Decimal128. func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDecimal { @@ -168,6 +224,19 @@ func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) e return vw.WriteDecimal128(val.Interface().(Decimal128)) } +func decimal128EncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + d128, ok := val.(Decimal128) + if !ok { + return ValueEncoderError{ + Name: "Decimal128EncodeValue", + Types: []reflect.Type{tDecimal}, + Received: reflect.ValueOf(val), + } + } + + return vw.WriteDecimal128(d128) +} + // jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJSONNumber { @@ -188,6 +257,29 @@ func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) return floatEncodeValue(ec, vw, reflect.ValueOf(f64)) } +func jsonNumberEncodeValueX(ec EncodeContext, vw ValueWriter, val any) error { + //if !val.IsValid() || val.Type() != tJSONNumber { + // return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} + //} + //jsnum := val.Interface().(json.Number) + jsnum, ok := val.(json.Number) + if !ok { + return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: reflect.ValueOf(val)} + } + + // Attempt int first, then float64 + if i64, err := jsnum.Int64(); err == nil { + return intEncodeValue(ec, vw, reflect.ValueOf(i64)) + } + + f64, err := jsnum.Float64() + if err != nil { + return err + } + + return floatEncodeValueX(ec, vw, f64) +} + // urlEncodeValue is the ValueEncoderFunc for url.URL. func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tURL { @@ -197,6 +289,15 @@ func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { return vw.WriteString(u.String()) } +func urlEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + u, ok := val.(url.URL) + if !ok { + return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: reflect.ValueOf(val)} + } + + return vw.WriteString(u.String()) +} + // arrayEncodeValue is the ValueEncoderFunc for array types. func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { @@ -346,6 +447,15 @@ func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) e return vw.WriteJavascript(val.String()) } +func javaScriptEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + jsString, ok := val.(JavaScript) + if !ok { + return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: reflect.ValueOf(val)} + } + + return vw.WriteJavascript(string(jsString)) +} + // symbolEncodeValue is the ValueEncoderFunc for the Symbol type. func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tSymbol { @@ -355,6 +465,15 @@ func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error return vw.WriteSymbol(val.String()) } +func symbolEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + symbol, ok := val.(Symbol) + if !ok { + return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: reflect.ValueOf(val)} + } + + return vw.WriteSymbol(string(symbol)) +} + // binaryEncodeValue is the ValueEncoderFunc for Binary. func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tBinary { @@ -365,6 +484,15 @@ func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } +func binaryEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + b, ok := val.(Binary) + if !ok { + return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: reflect.ValueOf(val)} + } + + return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) +} + // vectorEncodeValue is the ValueEncoderFunc for Vector. func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { t := val.Type() @@ -379,6 +507,16 @@ func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } +func vectorEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + v, ok := val.(Vector) + if !ok { + return ValueEncoderError{Name: "VectorEncodeValue", Types: []reflect.Type{tVector}, Received: reflect.ValueOf(val)} + } + + b := v.Binary() + return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) +} + // undefinedEncodeValue is the ValueEncoderFunc for Undefined. func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tUndefined { @@ -388,6 +526,14 @@ func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) er return vw.WriteUndefined() } +func undefinedEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + if _, ok := val.(Undefined); !ok { + return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: reflect.ValueOf(val)} + } + + return vw.WriteUndefined() +} + // dateTimeEncodeValue is the ValueEncoderFunc for DateTime. func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDateTime { @@ -397,6 +543,15 @@ func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) err return vw.WriteDateTime(val.Int()) } +func dateTimeEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + dateTime, ok := val.(DateTime) + if !ok { + return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: reflect.ValueOf(val)} + } + + return vw.WriteDateTime(int64(dateTime)) +} + // nullEncodeValue is the ValueEncoderFunc for Null. func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tNull { @@ -406,6 +561,18 @@ func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { return vw.WriteNull() } +func nullEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + if _, ok := val.(Null); !ok { + return ValueEncoderError{ + Name: "NullEncodeValue", + Types: []reflect.Type{tNull}, + Received: reflect.ValueOf(val), + } + } + + return vw.WriteNull() +} + // regexEncodeValue is the ValueEncoderFunc for Regex. func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRegex { @@ -417,6 +584,15 @@ func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error return vw.WriteRegex(regex.Pattern, regex.Options) } +func regexEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + regex, ok := val.(Regex) + if !ok { + return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: reflect.ValueOf(val)} + } + + return vw.WriteRegex(regex.Pattern, regex.Options) +} + // dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDBPointer { @@ -428,6 +604,15 @@ func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) er return vw.WriteDBPointer(dbp.DB, dbp.Pointer) } +func dbPointerEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + dbp, ok := val.(DBPointer) + if !ok { + return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: reflect.ValueOf(val)} + } + + return vw.WriteDBPointer(dbp.DB, dbp.Pointer) +} + // timestampEncodeValue is the ValueEncoderFunc for Timestamp. func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTimestamp { @@ -439,6 +624,15 @@ func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) er return vw.WriteTimestamp(ts.T, ts.I) } +func timestampEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + ts, ok := val.(Timestamp) + if !ok { + return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: reflect.ValueOf(val)} + } + + return vw.WriteTimestamp(ts.T, ts.I) +} + // minKeyEncodeValue is the ValueEncoderFunc for MinKey. func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMinKey { @@ -448,6 +642,14 @@ func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error return vw.WriteMinKey() } +func minKeyEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + if _, ok := val.(MinKey); !ok { + return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: reflect.ValueOf(val)} + } + + return vw.WriteMinKey() +} + // maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMaxKey { @@ -457,6 +659,14 @@ func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error return vw.WriteMaxKey() } +func maxKeyEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + if _, ok := val.(MaxKey); !ok { + return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: reflect.ValueOf(val)} + } + + return vw.WriteMaxKey() +} + // coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreDocument { @@ -468,6 +678,15 @@ func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) return copyDocumentFromBytes(vw, cdoc) } +func coreDocumentEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { + cdoc, ok := val.(bsoncore.Document) + if !ok { + return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: reflect.ValueOf(val)} + } + + return copyDocumentFromBytes(vw, cdoc) +} + // codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCodeWithScope { @@ -507,6 +726,43 @@ func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Valu return dw.WriteDocumentEnd() } +func codeWithScopeEncodeValueX(ec EncodeContext, vw ValueWriter, val any) error { + cws, ok := val.(CodeWithScope) + if !ok { + return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: reflect.ValueOf(val)} + } + + dw, err := vw.WriteCodeWithScope(string(cws.Code)) + if err != nil { + return err + } + + sw := sliceWriterPool.Get().(*sliceWriter) + defer sliceWriterPool.Put(sw) + *sw = (*sw)[:0] + + scopeVW := bvwPool.Get().(*valueWriter) + scopeVW.reset(scopeVW.buf[:0]) + scopeVW.w = sw + defer bvwPool.Put(scopeVW) + + encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope)) + if err != nil { + return err + } + + err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope)) + if err != nil { + return err + } + + err = copyBytesToDocumentWriter(dw, *sw) + if err != nil { + return err + } + return dw.WriteDocumentEnd() +} + // isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type func isImplementationNil(val reflect.Value, inter reflect.Type) bool { vt := val.Type() @@ -515,3 +771,55 @@ func isImplementationNil(val reflect.Value, inter reflect.Type) bool { } return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil() } + +func byteSliceEncodeValueRF(encodeNilAsEmpty bool) reflectFreeValueEncoderFunc { + return reflectFreeValueEncoderFunc(func(ec EncodeContext, vw ValueWriter, val any) error { + byteSlice, ok := val.([]byte) + if !ok { + return ValueEncoderError{ + Name: "ByteSliceEncodeValue", + Types: []reflect.Type{tByteSlice}, + Received: reflect.ValueOf(val), + } + } + + if byteSlice == nil && !encodeNilAsEmpty && !ec.nilByteSliceAsEmpty { + return vw.WriteNull() + } + + return vw.WriteBinary(byteSlice) + }) +} + +func byteSliceEncodeValue(encodeNilAsEmpty bool) defaultValueEncoderFunc { + return defaultValueEncoderFunc(func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return byteSliceEncodeValueRF(encodeNilAsEmpty)(ec, vw, val.Interface()) + }) +} + +func timeEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { + tt, ok := val.(time.Time) + if !ok { + return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: reflect.ValueOf(val)} + } + + dt := NewDateTimeFromTime(tt) + return vw.WriteDateTime(int64(dt)) +} + +func timeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return timeEncodeValueRF(ec, vw, val.Interface()) +} + +func coreArrayEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { + arr, ok := val.(bsoncore.Array) + if !ok { + return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: reflect.ValueOf(val)} + } + + return copyArrayFromBytes(vw, arr) +} + +func coreArrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return coreArrayEncodeValueRF(ec, vw, val.Interface()) +} diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index e15019785d..d2eb0c364a 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -51,7 +51,7 @@ func TestDefaultValueEncoders(t *testing.T) { type myfloat32 float32 type myfloat64 float64 - now := time.Now().Truncate(time.Millisecond) + //now := time.Now().Truncate(time.Millisecond) pjsnum := new(json.Number) *pjsnum = json.Number("3.14159") d128 := NewDecimal128(12345, 67890) @@ -194,21 +194,21 @@ func TestDefaultValueEncoders(t *testing.T) { {"float64/reflection path", myfloat64(3.14159), nil, nil, writeDouble, nil}, }, }, - { - "TimeEncodeValue", - &timeCodec{}, - []subtest{ - { - "wrong type", - wrong, - nil, - nil, - nothing, - ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: reflect.ValueOf(wrong)}, - }, - {"time.Time", now, nil, nil, writeDateTime, nil}, - }, - }, + //{ + // "TimeEncodeValue", + // &timeCodec{}, + // []subtest{ + // { + // "wrong type", + // wrong, + // nil, + // nil, + // nothing, + // ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: reflect.ValueOf(wrong)}, + // }, + // {"time.Time", now, nil, nil, writeDateTime, nil}, + // }, + //}, { "MapEncodeValue", &mapCodec{}, @@ -531,36 +531,39 @@ func TestDefaultValueEncoders(t *testing.T) { {"url.URL", url.URL{Scheme: "http", Host: "example.com"}, nil, nil, writeString, nil}, }, }, - { - "ByteSliceEncodeValue", - &byteSliceCodec{}, - []subtest{ - { - "wrong type", - wrong, - nil, - nil, - nothing, - ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: reflect.ValueOf(wrong)}, - }, - {"[]byte", []byte{0x01, 0x02, 0x03}, nil, nil, writeBinary, nil}, - {"[]byte/nil", []byte(nil), nil, nil, writeNull, nil}, - }, - }, - { - "EmptyInterfaceEncodeValue", - &emptyInterfaceCodec{}, - []subtest{ - { - "wrong type", - wrong, - nil, - nil, - nothing, - ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.ValueOf(wrong)}, - }, - }, - }, + //{ + // "ByteSliceEncodeValue", + // // TODO: Fix this. + // ValueEncoderFunc(func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + // return (&byteSliceCodec{}).EncodeValue(ec, vw, val) + // }), + // []subtest{ + // { + // "wrong type", + // wrong, + // nil, + // nil, + // nothing, + // ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: reflect.ValueOf(wrong)}, + // }, + // {"[]byte", []byte{0x01, 0x02, 0x03}, nil, nil, writeBinary, nil}, + // {"[]byte/nil", []byte(nil), nil, nil, writeNull, nil}, + // }, + //}, + //{ + // "EmptyInterfaceEncodeValue", + // &emptyInterfaceCodec{}, + // []subtest{ + // { + // "wrong type", + // wrong, + // nil, + // nil, + // nothing, + // ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.ValueOf(wrong)}, + // }, + // }, + //}, { "ValueMarshalerEncodeValue", ValueEncoderFunc(valueMarshalerEncodeValue), @@ -1048,53 +1051,53 @@ func TestDefaultValueEncoders(t *testing.T) { }, }, }, - { - "CoreArrayEncodeValue", - &arrayCodec{}, - []subtest{ - { - "wrong type", - wrong, - nil, - nil, - nothing, - ValueEncoderError{ - Name: "CoreArrayEncodeValue", - Types: []reflect.Type{tCoreArray}, - Received: reflect.ValueOf(wrong), - }, - }, + //{ + // "CoreArrayEncodeValue", + // &arrayCodec{}, + // []subtest{ + // { + // "wrong type", + // wrong, + // nil, + // nil, + // nothing, + // ValueEncoderError{ + // Name: "CoreArrayEncodeValue", + // Types: []reflect.Type{tCoreArray}, + // Received: reflect.ValueOf(wrong), + // }, + // }, - { - "WriteArray Error", - bsoncore.Array{}, - nil, - &valueReaderWriter{Err: errors.New("wa error"), ErrAfter: writeArray}, - writeArray, - errors.New("wa error"), - }, - { - "WriteArrayElement Error", - bsoncore.Array(buildDocumentArray(func([]byte) []byte { - return bsoncore.AppendNullElement(nil, "foo") - })), - nil, - &valueReaderWriter{Err: errors.New("wae error"), ErrAfter: writeArrayElement}, - writeArrayElement, - errors.New("wae error"), - }, - { - "encodeValue error", - bsoncore.Array(buildDocumentArray(func([]byte) []byte { - return bsoncore.AppendNullElement(nil, "foo") - })), - nil, - &valueReaderWriter{Err: errors.New("ev error"), ErrAfter: writeNull}, - writeNull, - errors.New("ev error"), - }, - }, - }, + // { + // "WriteArray Error", + // bsoncore.Array{}, + // nil, + // &valueReaderWriter{Err: errors.New("wa error"), ErrAfter: writeArray}, + // writeArray, + // errors.New("wa error"), + // }, + // { + // "WriteArrayElement Error", + // bsoncore.Array(buildDocumentArray(func([]byte) []byte { + // return bsoncore.AppendNullElement(nil, "foo") + // })), + // nil, + // &valueReaderWriter{Err: errors.New("wae error"), ErrAfter: writeArrayElement}, + // writeArrayElement, + // errors.New("wae error"), + // }, + // { + // "encodeValue error", + // bsoncore.Array(buildDocumentArray(func([]byte) []byte { + // return bsoncore.AppendNullElement(nil, "foo") + // })), + // nil, + // &valueReaderWriter{Err: errors.New("ev error"), ErrAfter: writeNull}, + // writeNull, + // errors.New("ev error"), + // }, + // }, + //}, } for _, tc := range testCases { diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go index f42935e5d8..a6b12328d8 100644 --- a/bson/mgoregistry.go +++ b/bson/mgoregistry.go @@ -38,11 +38,13 @@ func NewMgoRegistry() *Registry { uintCodec := &uintCodec{encodeToMinSize: true} reg := NewRegistry() + reg.registerReflectFreeTypeEncoder(tByteSlice, byteSliceEncodeValueRF(true)) + reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{decodeBinaryAsSlice: true}) reg.RegisterKindDecoder(reflect.String, ValueDecoderFunc(mgoStringDecodeValue)) reg.RegisterKindDecoder(reflect.Struct, structCodec) reg.RegisterKindDecoder(reflect.Map, mapCodec) - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: true}) + //reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: true}) reg.RegisterKindEncoder(reflect.Struct, structCodec) reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{encodeNilAsEmpty: true}) reg.RegisterKindEncoder(reflect.Map, mapCodec) @@ -69,8 +71,10 @@ func NewRespectNilValuesMgoRegistry() *Registry { } reg := NewMgoRegistry() + reg.registerReflectFreeTypeEncoder(tByteSlice, byteSliceEncodeValueRF(false)) + reg.RegisterKindDecoder(reflect.Map, mapCodec) - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: false}) + //reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: false}) reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{}) reg.RegisterKindEncoder(reflect.Map, mapCodec) return reg diff --git a/bson/registry.go b/bson/registry.go index d8f65ddc0d..53952c94b7 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -7,10 +7,15 @@ package bson import ( + "encoding/json" "errors" "fmt" + "net/url" "reflect" "sync" + "time" + + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) // defaultRegistry is the default Registry. It contains the default codecs and the @@ -88,6 +93,9 @@ type Registry struct { kindEncoders *kindEncoderCache kindDecoders *kindDecoderCache typeMap sync.Map // map[Type]reflect.Type + + reflectFreeTypeEncoders *typeReflectFreeEncoderCache + reflectFreeKindEncoders *kindEncoderReflectFreeCache } // NewRegistry creates a new empty Registry. @@ -97,6 +105,9 @@ func NewRegistry() *Registry { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), + + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) @@ -118,6 +129,18 @@ func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) r.typeEncoders.Store(valueType, enc) } +func (r *Registry) registerReflectFreeTypeEncoder(valueType reflect.Type, enc reflectFreeValueEncoder) { + r.reflectFreeTypeEncoders.Store(valueType, enc) +} + +func (r *Registry) registerReflectFreeKindEncoder(kind reflect.Kind, enc reflectFreeValueEncoder) { + r.reflectFreeKindEncoders.Store(kind, enc) +} + +func (r *Registry) storeReflectFreeTypeEncoder(rt reflect.Type, enc reflectFreeValueEncoder) reflectFreeValueEncoder { + return r.reflectFreeTypeEncoders.LoadOrStore(rt, enc) +} + // RegisterTypeDecoder registers the provided ValueDecoder for the provided type. // // The type will be used as provided, so a decoder can be registered for a type and a different @@ -223,6 +246,130 @@ func (r *Registry) RegisterTypeMapEntry(bt Type, rt reflect.Type) { r.typeMap.Store(bt, rt) } +func getReflectTypeFromAny(val any) (reflect.Type, error) { + switch v := val.(type) { + case bool: + return tBool, nil + case float64: + return tFloat64, nil + case int32: + return tInt32, nil + case int64: + return tInt64, nil + case string: + return tString, nil + case time.Time: + return tTime, nil + case interface{}: + return tEmpty, nil + case []byte: + return tByteSlice, nil + case byte: + return tByte, nil + case url.URL: + return tURL, nil + case json.Number: + return tJSONNumber, nil + case ValueMarshaler: + return tValueMarshaler, nil + case ValueUnmarshaler: + return tValueUnmarshaler, nil + case Marshaler: + return tMarshaler, nil + case Unmarshaler: + return tUnmarshaler, nil + case Zeroer: + return tZeroer, nil + case Binary: + return tBinary, nil + case Undefined: + return tUndefined, nil + case ObjectID: + return tOID, nil + case DateTime: + return tDateTime, nil + case Null: + return tNull, nil + case Regex: + return tRegex, nil + case CodeWithScope: + return tCodeWithScope, nil + case DBPointer: + return tDBPointer, nil + case JavaScript: + return tJavaScript, nil + case Symbol: + return tSymbol, nil + case Timestamp: + return tTimestamp, nil + case Decimal128: + return tDecimal, nil + case Vector: + return tVector, nil + case MinKey: + return tMinKey, nil + case MaxKey: + return tMaxKey, nil + case D: + return tD, nil + case A: + return tA, nil + case E: + return tE, nil + case bsoncore.Document: + return tCoreDocument, nil + case bsoncore.Array: + return tCoreArray, nil + default: + return nil, fmt.Errorf("no default encoder for type %T", v) + } +} + +//func lookupReflectFreeEncoder(r *Registry, typ reflect.Type) (reflectFreeValueEncoder, error) { +//} + +func lookupEncoderReflectFree(r *Registry, typ reflect.Type, val any) (reflectFreeValueEncoder, error) { + if typ == nil { + var err error + + typ, err = getReflectTypeFromAny(val) + if err != nil { + return nil, err + } + } + + rfeEnc, found := r.reflectFreeTypeEncoders.Load(typ) + if !found { + return nil, errNoEncoder{Type: typ} + } + + return rfeEnc, nil +} + +func lookupUserDefinedEncoder(r *Registry, val any) (ValueEncoder, reflect.Type, bool, error) { + typ, err := getReflectTypeFromAny(val) + if err != nil { + return nil, nil, false, err + } + + enc, found := r.lookupTypeEncoder(typ) + if found { + if enc == nil { + return nil, typ, false, errNoEncoder{Type: typ} + } + + // We do not use ValueEncoder in the default case, preferring a reflect-free + // solution. + if _, ok := enc.(defaultValueEncoderFunc); ok { + return nil, typ, false, nil + } + + return enc, typ, true, nil + } + + return nil, typ, false, nil +} + // LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup // order: // @@ -244,6 +391,21 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { if valueType == nil { return nil, errNoEncoder{Type: valueType} } + + // First attempt to lookup a reflect-free default encoder. + // TODO: This will be moved in favor of the lookup* solution. + rfeEnc, found := r.reflectFreeTypeEncoders.Load(valueType) + if found { + if rfeEnc != nil { + wrapper := func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return rfeEnc.EncodeValue(ec, vw, val.Interface()) + } + + return ValueEncoderFunc(wrapper), nil + } + } + + // Then lookup a user-defined encoder. enc, found := r.lookupTypeEncoder(valueType) if found { if enc == nil { @@ -257,6 +419,15 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { return r.typeEncoders.LoadOrStore(valueType, enc), nil } + if v, ok := r.reflectFreeKindEncoders.Load(valueType.Kind()); ok { + ve := r.storeReflectFreeTypeEncoder(valueType, v) + wrapper := func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return ve.EncodeValue(ec, vw, val.Interface()) + } + + return ValueEncoderFunc(wrapper), nil + } + if v, ok := r.kindEncoders.Load(valueType.Kind()); ok { return r.storeTypeEncoder(valueType, v), nil } diff --git a/bson/registry_test.go b/bson/registry_test.go index ea7b2b2ef7..2d05c6a22a 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -22,6 +22,9 @@ func newTestRegistry() *Registry { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), + + reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } } diff --git a/bson/time_codec.go b/bson/time_codec.go index 1c00374c19..85d37496fc 100644 --- a/bson/time_codec.go +++ b/bson/time_codec.go @@ -99,11 +99,12 @@ func (tc *timeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.V } // EncodeValue is the ValueEncoderFunc for time.TIme. -func (tc *timeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tTime { - return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} +func (tc *timeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val any) error { + timeVal, ok := val.(time.Time) + if !ok { + return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: reflect.ValueOf(val)} } - tt := val.Interface().(time.Time) - dt := NewDateTimeFromTime(tt) + + dt := NewDateTimeFromTime(timeVal) return vw.WriteDateTime(int64(dt)) } From 760fe0570fff2ffaa1dd32c559c7631ff5e1360e Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 10 Apr 2025 14:24:14 -0600 Subject: [PATCH 02/13] GODRIVER-3455 Order LookupEncoder correctly --- bson/default_value_encoders.go | 25 ++++-- bson/mgoregistry.go | 2 - bson/registry.go | 160 +++------------------------------ 3 files changed, 32 insertions(+), 155 deletions(-) diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index c20b28dc89..019e23682c 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -63,6 +63,7 @@ func registerDefaultEncoders(reg *Registry) { // Register the reflect-free default type encoders. reg.registerReflectFreeTypeEncoder(tByteSlice, byteSliceEncodeValueRF(false)) reg.registerReflectFreeTypeEncoder(tTime, reflectFreeValueEncoderFunc(timeEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tEmpty, reflectFreeValueEncoderFunc(emptyInterfaceValueRF)) reg.registerReflectFreeTypeEncoder(tCoreArray, reflectFreeValueEncoderFunc(coreArrayEncodeValueRF)) reg.registerReflectFreeTypeEncoder(tNull, reflectFreeValueEncoderFunc(nullEncodeValueX)) reg.registerReflectFreeTypeEncoder(tOID, reflectFreeValueEncoderFunc(objectIDEncodeValueX)) @@ -89,7 +90,7 @@ func registerDefaultEncoders(reg *Registry) { // reg.RegisterTypeEncoder(tByteSlice, byteSliceEncodeValue(false)) reg.RegisterTypeEncoder(tTime, defaultValueEncoderFunc(timeEncodeValue)) - reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{}) // TODO: extend this to reflection free + reg.RegisterTypeEncoder(tEmpty, defaultValueEncoderFunc(emptyInterfaceValue)) // TODO: extend this to reflection free reg.RegisterTypeEncoder(tCoreArray, defaultValueEncoderFunc(coreArrayEncodeValue)) reg.RegisterTypeEncoder(tOID, defaultValueEncoderFunc(objectIDEncodeValue)) reg.RegisterTypeEncoder(tDecimal, defaultValueEncoderFunc(decimal128EncodeValue)) @@ -258,10 +259,6 @@ func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) } func jsonNumberEncodeValueX(ec EncodeContext, vw ValueWriter, val any) error { - //if !val.IsValid() || val.Type() != tJSONNumber { - // return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} - //} - //jsnum := val.Interface().(json.Number) jsnum, ok := val.(json.Number) if !ok { return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: reflect.ValueOf(val)} @@ -708,7 +705,6 @@ func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Valu scopeVW.reset(scopeVW.buf[:0]) scopeVW.w = sw defer bvwPool.Put(scopeVW) - encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope)) if err != nil { return err @@ -745,7 +741,6 @@ func codeWithScopeEncodeValueX(ec EncodeContext, vw ValueWriter, val any) error scopeVW.reset(scopeVW.buf[:0]) scopeVW.w = sw defer bvwPool.Put(scopeVW) - encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope)) if err != nil { return err @@ -823,3 +818,19 @@ func coreArrayEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { func coreArrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { return coreArrayEncodeValueRF(ec, vw, val.Interface()) } + +func emptyInterfaceValueRF(ec EncodeContext, vw ValueWriter, val any) error { + if val == nil { + return vw.WriteNull() + } + encoder, err := ec.LookupEncoder(reflect.TypeOf(val)) + if err != nil { + return err + } + + return encoder.EncodeValue(ec, vw, reflect.ValueOf(val)) +} + +func emptyInterfaceValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return emptyInterfaceValueRF(ec, vw, val.Interface()) +} diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go index a6b12328d8..6e11353168 100644 --- a/bson/mgoregistry.go +++ b/bson/mgoregistry.go @@ -44,7 +44,6 @@ func NewMgoRegistry() *Registry { reg.RegisterKindDecoder(reflect.String, ValueDecoderFunc(mgoStringDecodeValue)) reg.RegisterKindDecoder(reflect.Struct, structCodec) reg.RegisterKindDecoder(reflect.Map, mapCodec) - //reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: true}) reg.RegisterKindEncoder(reflect.Struct, structCodec) reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{encodeNilAsEmpty: true}) reg.RegisterKindEncoder(reflect.Map, mapCodec) @@ -74,7 +73,6 @@ func NewRespectNilValuesMgoRegistry() *Registry { reg.registerReflectFreeTypeEncoder(tByteSlice, byteSliceEncodeValueRF(false)) reg.RegisterKindDecoder(reflect.Map, mapCodec) - //reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: false}) reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{}) reg.RegisterKindEncoder(reflect.Map, mapCodec) return reg diff --git a/bson/registry.go b/bson/registry.go index 53952c94b7..4db6f870b5 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -7,15 +7,10 @@ package bson import ( - "encoding/json" "errors" "fmt" - "net/url" "reflect" "sync" - "time" - - "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) // defaultRegistry is the default Registry. It contains the default codecs and the @@ -246,130 +241,6 @@ func (r *Registry) RegisterTypeMapEntry(bt Type, rt reflect.Type) { r.typeMap.Store(bt, rt) } -func getReflectTypeFromAny(val any) (reflect.Type, error) { - switch v := val.(type) { - case bool: - return tBool, nil - case float64: - return tFloat64, nil - case int32: - return tInt32, nil - case int64: - return tInt64, nil - case string: - return tString, nil - case time.Time: - return tTime, nil - case interface{}: - return tEmpty, nil - case []byte: - return tByteSlice, nil - case byte: - return tByte, nil - case url.URL: - return tURL, nil - case json.Number: - return tJSONNumber, nil - case ValueMarshaler: - return tValueMarshaler, nil - case ValueUnmarshaler: - return tValueUnmarshaler, nil - case Marshaler: - return tMarshaler, nil - case Unmarshaler: - return tUnmarshaler, nil - case Zeroer: - return tZeroer, nil - case Binary: - return tBinary, nil - case Undefined: - return tUndefined, nil - case ObjectID: - return tOID, nil - case DateTime: - return tDateTime, nil - case Null: - return tNull, nil - case Regex: - return tRegex, nil - case CodeWithScope: - return tCodeWithScope, nil - case DBPointer: - return tDBPointer, nil - case JavaScript: - return tJavaScript, nil - case Symbol: - return tSymbol, nil - case Timestamp: - return tTimestamp, nil - case Decimal128: - return tDecimal, nil - case Vector: - return tVector, nil - case MinKey: - return tMinKey, nil - case MaxKey: - return tMaxKey, nil - case D: - return tD, nil - case A: - return tA, nil - case E: - return tE, nil - case bsoncore.Document: - return tCoreDocument, nil - case bsoncore.Array: - return tCoreArray, nil - default: - return nil, fmt.Errorf("no default encoder for type %T", v) - } -} - -//func lookupReflectFreeEncoder(r *Registry, typ reflect.Type) (reflectFreeValueEncoder, error) { -//} - -func lookupEncoderReflectFree(r *Registry, typ reflect.Type, val any) (reflectFreeValueEncoder, error) { - if typ == nil { - var err error - - typ, err = getReflectTypeFromAny(val) - if err != nil { - return nil, err - } - } - - rfeEnc, found := r.reflectFreeTypeEncoders.Load(typ) - if !found { - return nil, errNoEncoder{Type: typ} - } - - return rfeEnc, nil -} - -func lookupUserDefinedEncoder(r *Registry, val any) (ValueEncoder, reflect.Type, bool, error) { - typ, err := getReflectTypeFromAny(val) - if err != nil { - return nil, nil, false, err - } - - enc, found := r.lookupTypeEncoder(typ) - if found { - if enc == nil { - return nil, typ, false, errNoEncoder{Type: typ} - } - - // We do not use ValueEncoder in the default case, preferring a reflect-free - // solution. - if _, ok := enc.(defaultValueEncoderFunc); ok { - return nil, typ, false, nil - } - - return enc, typ, true, nil - } - - return nil, typ, false, nil -} - // LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup // order: // @@ -392,30 +263,27 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { return nil, errNoEncoder{Type: valueType} } - // First attempt to lookup a reflect-free default encoder. - // TODO: This will be moved in favor of the lookup* solution. - rfeEnc, found := r.reflectFreeTypeEncoders.Load(valueType) - if found { - if rfeEnc != nil { - wrapper := func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return rfeEnc.EncodeValue(ec, vw, val.Interface()) - } + // First attempt to get a user-defined type encoder. + if enc, found := r.lookupTypeEncoder(valueType); found { + if enc == nil { + return nil, errNoEncoder{Type: valueType} + } - return ValueEncoderFunc(wrapper), nil + if _, ok := enc.(defaultValueEncoderFunc); !ok { + return enc, nil } } - // Then lookup a user-defined encoder. - enc, found := r.lookupTypeEncoder(valueType) - if found { - if enc == nil { - return nil, errNoEncoder{Type: valueType} + // Next try to get a reflection-free encoder. + if rfeEnc, found := r.reflectFreeTypeEncoders.Load(valueType); found && rfeEnc != nil { + wrapper := func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return rfeEnc.EncodeValue(ec, vw, val.Interface()) } - return enc, nil + + return ValueEncoderFunc(wrapper), nil } - enc, found = r.lookupInterfaceEncoder(valueType, true) - if found { + if enc, found := r.lookupInterfaceEncoder(valueType, true); found { return r.typeEncoders.LoadOrStore(valueType, enc), nil } From b3fbcef9f1ef6e695eb8facd87b7d39fed128051 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 10 Apr 2025 15:32:57 -0600 Subject: [PATCH 03/13] GODRIVER-3455 Add RF suffix to encoders --- bson/default_value_encoders.go | 209 +++++++++++---------------------- 1 file changed, 70 insertions(+), 139 deletions(-) diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 019e23682c..bee381576a 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -65,17 +65,17 @@ func registerDefaultEncoders(reg *Registry) { reg.registerReflectFreeTypeEncoder(tTime, reflectFreeValueEncoderFunc(timeEncodeValueRF)) reg.registerReflectFreeTypeEncoder(tEmpty, reflectFreeValueEncoderFunc(emptyInterfaceValueRF)) reg.registerReflectFreeTypeEncoder(tCoreArray, reflectFreeValueEncoderFunc(coreArrayEncodeValueRF)) - reg.registerReflectFreeTypeEncoder(tNull, reflectFreeValueEncoderFunc(nullEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tOID, reflectFreeValueEncoderFunc(objectIDEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tDecimal, reflectFreeValueEncoderFunc(decimal128EncodeValueX)) - reg.registerReflectFreeTypeEncoder(tJSONNumber, reflectFreeValueEncoderFunc(jsonNumberEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tURL, reflectFreeValueEncoderFunc(urlEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tJavaScript, reflectFreeValueEncoderFunc(javaScriptEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tSymbol, reflectFreeValueEncoderFunc(symbolEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tBinary, reflectFreeValueEncoderFunc(binaryEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tVector, reflectFreeValueEncoderFunc(vectorEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tUndefined, reflectFreeValueEncoderFunc(undefinedEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tDateTime, reflectFreeValueEncoderFunc(dateTimeEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tNull, reflectFreeValueEncoderFunc(nullEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tOID, reflectFreeValueEncoderFunc(objectIDEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tDecimal, reflectFreeValueEncoderFunc(decimal128EncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tJSONNumber, reflectFreeValueEncoderFunc(jsonNumberEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tURL, reflectFreeValueEncoderFunc(urlEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tJavaScript, reflectFreeValueEncoderFunc(javaScriptEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tSymbol, reflectFreeValueEncoderFunc(symbolEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tBinary, reflectFreeValueEncoderFunc(binaryEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tVector, reflectFreeValueEncoderFunc(vectorEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tUndefined, reflectFreeValueEncoderFunc(undefinedEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tDateTime, reflectFreeValueEncoderFunc(dateTimeEncodeValueRF)) reg.registerReflectFreeTypeEncoder(tRegex, reflectFreeValueEncoderFunc(regexEncodeValueX)) reg.registerReflectFreeTypeEncoder(tDBPointer, reflectFreeValueEncoderFunc(dbPointerEncodeValueX)) reg.registerReflectFreeTypeEncoder(tTimestamp, reflectFreeValueEncoderFunc(timestampEncodeValueX)) @@ -176,35 +176,25 @@ func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { } } -// floatEncodeValue is the ValueEncoderFunc for float types. -func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - switch val.Kind() { - case reflect.Float32, reflect.Float64: - return vw.WriteDouble(val.Float()) +func floatEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { + if f32, ok := val.(float32); ok { + return vw.WriteDouble(float64(f32)) } - return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val} -} - -func floatEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { - switch val := val.(type) { - case float32, float64: - return vw.WriteDouble(val.(float64)) + if f64, ok := val.(float64); ok { + return vw.WriteDouble(f64) } return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: reflect.ValueOf(val)} } -// objectIDEncodeValue is the ValueEncoderFunc for ObjectID. -func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tOID { - return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val} - } - return vw.WriteObjectID(val.Interface().(ObjectID)) +// floatEncodeValue is the ValueEncoderFunc for float types. +func floatEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return floatEncodeValueRF(ec, vw, val.Interface()) } // objectIDEncodeValue is the ValueEncoderFunc for ObjectID. -func objectIDEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func objectIDEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { objID, ok := val.(ObjectID) if !ok { return ValueEncoderError{ @@ -217,48 +207,26 @@ func objectIDEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteObjectID(objID) } -// decimal128EncodeValue is the ValueEncoderFunc for Decimal128. -func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tDecimal { - return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val} - } - return vw.WriteDecimal128(val.Interface().(Decimal128)) +// objectIDEncodeValue is the ValueEncoderFunc for ObjectID. +func objectIDEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return objectIDEncodeValueRF(ec, vw, val.Interface()) } -func decimal128EncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func decimal128EncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { d128, ok := val.(Decimal128) if !ok { - return ValueEncoderError{ - Name: "Decimal128EncodeValue", - Types: []reflect.Type{tDecimal}, - Received: reflect.ValueOf(val), - } + return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: reflect.ValueOf(val)} } return vw.WriteDecimal128(d128) } -// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. -func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tJSONNumber { - return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} - } - jsnum := val.Interface().(json.Number) - - // Attempt int first, then float64 - if i64, err := jsnum.Int64(); err == nil { - return intEncodeValue(ec, vw, reflect.ValueOf(i64)) - } - - f64, err := jsnum.Float64() - if err != nil { - return err - } - - return floatEncodeValue(ec, vw, reflect.ValueOf(f64)) +// decimal128EncodeValue is the ValueEncoderFunc for Decimal128. +func decimal128EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return decimal128EncodeValueRF(ec, vw, val.Interface()) } -func jsonNumberEncodeValueX(ec EncodeContext, vw ValueWriter, val any) error { +func jsonNumberEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { jsnum, ok := val.(json.Number) if !ok { return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: reflect.ValueOf(val)} @@ -274,19 +242,15 @@ func jsonNumberEncodeValueX(ec EncodeContext, vw ValueWriter, val any) error { return err } - return floatEncodeValueX(ec, vw, f64) + return floatEncodeValueRF(ec, vw, f64) } -// urlEncodeValue is the ValueEncoderFunc for url.URL. -func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tURL { - return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val} - } - u := val.Interface().(url.URL) - return vw.WriteString(u.String()) +// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. +func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return jsonNumberEncodeValueRF(ec, vw, val.Interface()) } -func urlEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func urlEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { u, ok := val.(url.URL) if !ok { return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: reflect.ValueOf(val)} @@ -295,6 +259,11 @@ func urlEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteString(u.String()) } +// urlEncodeValue is the ValueEncoderFunc for url.URL. +func urlEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return urlEncodeValueRF(ec, vw, val.Interface()) +} + // arrayEncodeValue is the ValueEncoderFunc for array types. func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { @@ -435,16 +404,7 @@ func marshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) er return copyValueFromBytes(vw, TypeEmbeddedDocument, data) } -// javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. -func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tJavaScript { - return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val} - } - - return vw.WriteJavascript(val.String()) -} - -func javaScriptEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func javaScriptEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { jsString, ok := val.(JavaScript) if !ok { return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: reflect.ValueOf(val)} @@ -453,16 +413,12 @@ func javaScriptEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteJavascript(string(jsString)) } -// symbolEncodeValue is the ValueEncoderFunc for the Symbol type. -func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tSymbol { - return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val} - } - - return vw.WriteSymbol(val.String()) +// javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. +func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + return javaScriptEncodeValueRF(EncodeContext{}, vw, val.Interface()) } -func symbolEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func symbolEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { symbol, ok := val.(Symbol) if !ok { return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: reflect.ValueOf(val)} @@ -471,17 +427,12 @@ func symbolEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteSymbol(string(symbol)) } -// binaryEncodeValue is the ValueEncoderFunc for Binary. -func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tBinary { - return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val} - } - b := val.Interface().(Binary) - - return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) +// symbolEncodeValue is the ValueEncoderFunc for the Symbol type. +func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + return symbolEncodeValueRF(EncodeContext{}, vw, val.Interface()) } -func binaryEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func binaryEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { b, ok := val.(Binary) if !ok { return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: reflect.ValueOf(val)} @@ -490,21 +441,12 @@ func binaryEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } -// vectorEncodeValue is the ValueEncoderFunc for Vector. -func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - t := val.Type() - if !val.IsValid() || t != tVector { - return ValueEncoderError{Name: "VectorEncodeValue", - Types: []reflect.Type{tVector}, - Received: val, - } - } - v := val.Interface().(Vector) - b := v.Binary() - return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) +// binaryEncodeValue is the ValueEncoderFunc for Binary. +func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + return binaryEncodeValueRF(EncodeContext{}, vw, val.Interface()) } -func vectorEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func vectorEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { v, ok := val.(Vector) if !ok { return ValueEncoderError{Name: "VectorEncodeValue", Types: []reflect.Type{tVector}, Received: reflect.ValueOf(val)} @@ -514,16 +456,12 @@ func vectorEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } -// undefinedEncodeValue is the ValueEncoderFunc for Undefined. -func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tUndefined { - return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val} - } - - return vw.WriteUndefined() +// vectorEncodeValue is the ValueEncoderFunc for Vector. +func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + return vectorEncodeValueRF(EncodeContext{}, vw, val.Interface()) } -func undefinedEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func undefinedEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { if _, ok := val.(Undefined); !ok { return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: reflect.ValueOf(val)} } @@ -531,16 +469,12 @@ func undefinedEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteUndefined() } -// dateTimeEncodeValue is the ValueEncoderFunc for DateTime. -func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tDateTime { - return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val} - } - - return vw.WriteDateTime(val.Int()) +// undefinedEncodeValue is the ValueEncoderFunc for Undefined. +func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + return undefinedEncodeValueRF(EncodeContext{}, vw, val.Interface()) } -func dateTimeEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func dateTimeEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { dateTime, ok := val.(DateTime) if !ok { return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: reflect.ValueOf(val)} @@ -549,27 +483,24 @@ func dateTimeEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteDateTime(int64(dateTime)) } -// nullEncodeValue is the ValueEncoderFunc for Null. -func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tNull { - return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val} - } - - return vw.WriteNull() +// dateTimeEncodeValue is the ValueEncoderFunc for DateTime. +func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + return dateTimeEncodeValueRF(EncodeContext{}, vw, val.Interface()) } -func nullEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func nullEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { if _, ok := val.(Null); !ok { - return ValueEncoderError{ - Name: "NullEncodeValue", - Types: []reflect.Type{tNull}, - Received: reflect.ValueOf(val), - } + return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: reflect.ValueOf(val)} } return vw.WriteNull() } +// nullEncodeValue is the ValueEncoderFunc for Null. +func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + return nullEncodeValueRF(EncodeContext{}, vw, val.Interface()) +} + // regexEncodeValue is the ValueEncoderFunc for Regex. func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRegex { From 1cee2ee3688eda80fa3d0e3db97f80d822402cb4 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 10 Apr 2025 15:34:20 -0600 Subject: [PATCH 04/13] GODRIVER-3455 Remove kind RF cache --- bson/codec_cache.go | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/bson/codec_cache.go b/bson/codec_cache.go index 472d11ce23..4fffe0eff1 100644 --- a/bson/codec_cache.go +++ b/bson/codec_cache.go @@ -24,7 +24,6 @@ func init() { // statically assert array size var _ = (kindEncoderCache{}).entries[reflect.UnsafePointer] var _ = (kindDecoderCache{}).entries[reflect.UnsafePointer] -var _ = (kindEncoderReflectFreeCache{}).entries[reflect.UnsafePointer] type typeEncoderCache struct { cache sync.Map // map[reflect.Type]ValueEncoder @@ -129,39 +128,6 @@ func (c *typeDecoderCache) Clone() *typeDecoderCache { // so we wrap the ValueEncoder with a kindEncoderCacheEntry to ensure the type // is always the same (since different concrete types may implement the // ValueEncoder interface). -type kindEncoderReflectFreeCacheEntry struct { - enc reflectFreeValueEncoder -} - -type kindEncoderReflectFreeCache struct { - entries [reflect.UnsafePointer + 1]atomic.Value // *kindEncoderCacheEntry -} - -func (c *kindEncoderReflectFreeCache) Store(rt reflect.Kind, enc reflectFreeValueEncoder) { - if enc != nil && rt < reflect.Kind(len(c.entries)) { - c.entries[rt].Store(&kindEncoderReflectFreeCacheEntry{enc: enc}) - } -} - -func (c *kindEncoderReflectFreeCache) Load(rt reflect.Kind) (reflectFreeValueEncoder, bool) { - if rt < reflect.Kind(len(c.entries)) { - if ent, ok := c.entries[rt].Load().(*kindEncoderReflectFreeCacheEntry); ok { - return ent.enc, ent.enc != nil - } - } - return nil, false -} - -func (c *kindEncoderReflectFreeCache) Clone() *kindEncoderReflectFreeCache { - cc := new(kindEncoderReflectFreeCache) - for i, v := range c.entries { - if val := v.Load(); val != nil { - cc.entries[i].Store(val) - } - } - return cc -} - type kindEncoderCacheEntry struct { enc ValueEncoder } From 4bb84425e461905f6c25ffab3b2103539f6ed327 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 10 Apr 2025 15:45:13 -0600 Subject: [PATCH 05/13] GODRIVER-3455 Cont. removing kind udpates --- bson/default_value_decoders_test.go | 6 -- bson/default_value_encoders.go | 11 +++- bson/default_value_encoders_test.go | 94 +++++++++++++++-------------- bson/registry.go | 15 ----- bson/registry_test.go | 1 - 5 files changed, 57 insertions(+), 70 deletions(-) diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 43379d1d14..269c287fce 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -3420,7 +3420,6 @@ func TestDefaultValueDecoders(t *testing.T) { kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), - reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultEncoders(topLevelReg) registerDefaultDecoders(topLevelReg) @@ -3433,7 +3432,6 @@ func TestDefaultValueDecoders(t *testing.T) { kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), - reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultEncoders(embeddedReg) registerDefaultDecoders(embeddedReg) @@ -3482,7 +3480,6 @@ func TestDefaultValueDecoders(t *testing.T) { kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), - reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) @@ -3579,7 +3576,6 @@ func TestDefaultValueDecoders(t *testing.T) { kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), - reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultDecoders(nestedRegistry) nestedRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) @@ -3739,7 +3735,6 @@ func TestDefaultValueDecoders(t *testing.T) { kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), - reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultDecoders(reg) reg.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) @@ -3816,7 +3811,6 @@ func buildDefaultRegistry() *Registry { kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), - reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index bee381576a..5e1a631350 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -188,9 +188,16 @@ func floatEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: reflect.ValueOf(val)} } -// floatEncodeValue is the ValueEncoderFunc for float types. +// floatEncodeValue is the ValueEncoderFunc for float types. this function is +// used to decode "types" and "kinds" and therefore cannot be a wrapper for +// reflection-free decoding in the default "type" case. func floatEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return floatEncodeValueRF(ec, vw, val.Interface()) + switch val.Kind() { + case reflect.Float32, reflect.Float64: + return vw.WriteDouble(val.Float()) + } + + return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val} } // objectIDEncodeValue is the ValueEncoderFunc for ObjectID. diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index d2eb0c364a..ac2bc6770a 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -1051,53 +1051,55 @@ func TestDefaultValueEncoders(t *testing.T) { }, }, }, - //{ - // "CoreArrayEncodeValue", - // &arrayCodec{}, - // []subtest{ - // { - // "wrong type", - // wrong, - // nil, - // nil, - // nothing, - // ValueEncoderError{ - // Name: "CoreArrayEncodeValue", - // Types: []reflect.Type{tCoreArray}, - // Received: reflect.ValueOf(wrong), - // }, - // }, + { + "CoreArrayEncodeValue", + ValueEncoderFunc(func(ec EncodeContext, vw ValueWriter, v reflect.Value) error { + return coreArrayEncodeValueRF(ec, vw, v.Interface()) + }), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + nothing, + ValueEncoderError{ + Name: "CoreArrayEncodeValue", + Types: []reflect.Type{tCoreArray}, + Received: reflect.ValueOf(wrong), + }, + }, - // { - // "WriteArray Error", - // bsoncore.Array{}, - // nil, - // &valueReaderWriter{Err: errors.New("wa error"), ErrAfter: writeArray}, - // writeArray, - // errors.New("wa error"), - // }, - // { - // "WriteArrayElement Error", - // bsoncore.Array(buildDocumentArray(func([]byte) []byte { - // return bsoncore.AppendNullElement(nil, "foo") - // })), - // nil, - // &valueReaderWriter{Err: errors.New("wae error"), ErrAfter: writeArrayElement}, - // writeArrayElement, - // errors.New("wae error"), - // }, - // { - // "encodeValue error", - // bsoncore.Array(buildDocumentArray(func([]byte) []byte { - // return bsoncore.AppendNullElement(nil, "foo") - // })), - // nil, - // &valueReaderWriter{Err: errors.New("ev error"), ErrAfter: writeNull}, - // writeNull, - // errors.New("ev error"), - // }, - // }, - //}, + { + "WriteArray Error", + bsoncore.Array{}, + nil, + &valueReaderWriter{Err: errors.New("wa error"), ErrAfter: writeArray}, + writeArray, + errors.New("wa error"), + }, + { + "WriteArrayElement Error", + bsoncore.Array(buildDocumentArray(func([]byte) []byte { + return bsoncore.AppendNullElement(nil, "foo") + })), + nil, + &valueReaderWriter{Err: errors.New("wae error"), ErrAfter: writeArrayElement}, + writeArrayElement, + errors.New("wae error"), + }, + { + "encodeValue error", + bsoncore.Array(buildDocumentArray(func([]byte) []byte { + return bsoncore.AppendNullElement(nil, "foo") + })), + nil, + &valueReaderWriter{Err: errors.New("ev error"), ErrAfter: writeNull}, + writeNull, + errors.New("ev error"), + }, + }, + }, } for _, tc := range testCases { diff --git a/bson/registry.go b/bson/registry.go index 4db6f870b5..3bca9dcb8d 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -90,7 +90,6 @@ type Registry struct { typeMap sync.Map // map[Type]reflect.Type reflectFreeTypeEncoders *typeReflectFreeEncoderCache - reflectFreeKindEncoders *kindEncoderReflectFreeCache } // NewRegistry creates a new empty Registry. @@ -102,7 +101,6 @@ func NewRegistry() *Registry { kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), - reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) @@ -128,10 +126,6 @@ func (r *Registry) registerReflectFreeTypeEncoder(valueType reflect.Type, enc re r.reflectFreeTypeEncoders.Store(valueType, enc) } -func (r *Registry) registerReflectFreeKindEncoder(kind reflect.Kind, enc reflectFreeValueEncoder) { - r.reflectFreeKindEncoders.Store(kind, enc) -} - func (r *Registry) storeReflectFreeTypeEncoder(rt reflect.Type, enc reflectFreeValueEncoder) reflectFreeValueEncoder { return r.reflectFreeTypeEncoders.LoadOrStore(rt, enc) } @@ -287,15 +281,6 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { return r.typeEncoders.LoadOrStore(valueType, enc), nil } - if v, ok := r.reflectFreeKindEncoders.Load(valueType.Kind()); ok { - ve := r.storeReflectFreeTypeEncoder(valueType, v) - wrapper := func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return ve.EncodeValue(ec, vw, val.Interface()) - } - - return ValueEncoderFunc(wrapper), nil - } - if v, ok := r.kindEncoders.Load(valueType.Kind()); ok { return r.storeTypeEncoder(valueType, v), nil } diff --git a/bson/registry_test.go b/bson/registry_test.go index 2d05c6a22a..f826ab5e78 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -24,7 +24,6 @@ func newTestRegistry() *Registry { kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), - reflectFreeKindEncoders: new(kindEncoderReflectFreeCache), } } From 6b3efad5b8d4181569462295d9cc1c3b34869ed1 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 10 Apr 2025 15:54:09 -0600 Subject: [PATCH 06/13] GODRIVER-3455 Fix uncommted default encoder tests --- bson/default_value_encoders.go | 15 ++++- bson/default_value_encoders_test.go | 99 ++++++++++++++--------------- 2 files changed, 61 insertions(+), 53 deletions(-) diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 5e1a631350..378f9ae700 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -770,5 +770,18 @@ func emptyInterfaceValueRF(ec EncodeContext, vw ValueWriter, val any) error { } func emptyInterfaceValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return emptyInterfaceValueRF(ec, vw, val.Interface()) + if !val.IsValid() || val.Type() != tEmpty { + return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} + } + + if val.IsNil() { + return vw.WriteNull() + } + + encoder, err := ec.LookupEncoder(val.Elem().Type()) + if err != nil { + return err + } + + return encoder.EncodeValue(ec, vw, val.Elem()) } diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index ac2bc6770a..1128cf4776 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -51,7 +51,7 @@ func TestDefaultValueEncoders(t *testing.T) { type myfloat32 float32 type myfloat64 float64 - //now := time.Now().Truncate(time.Millisecond) + now := time.Now().Truncate(time.Millisecond) pjsnum := new(json.Number) *pjsnum = json.Number("3.14159") d128 := NewDecimal128(12345, 67890) @@ -194,21 +194,21 @@ func TestDefaultValueEncoders(t *testing.T) { {"float64/reflection path", myfloat64(3.14159), nil, nil, writeDouble, nil}, }, }, - //{ - // "TimeEncodeValue", - // &timeCodec{}, - // []subtest{ - // { - // "wrong type", - // wrong, - // nil, - // nil, - // nothing, - // ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: reflect.ValueOf(wrong)}, - // }, - // {"time.Time", now, nil, nil, writeDateTime, nil}, - // }, - //}, + { + "TimeEncodeValue", + ValueEncoderFunc(timeEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + nothing, + ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: reflect.ValueOf(wrong)}, + }, + {"time.Time", now, nil, nil, writeDateTime, nil}, + }, + }, { "MapEncodeValue", &mapCodec{}, @@ -531,39 +531,36 @@ func TestDefaultValueEncoders(t *testing.T) { {"url.URL", url.URL{Scheme: "http", Host: "example.com"}, nil, nil, writeString, nil}, }, }, - //{ - // "ByteSliceEncodeValue", - // // TODO: Fix this. - // ValueEncoderFunc(func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - // return (&byteSliceCodec{}).EncodeValue(ec, vw, val) - // }), - // []subtest{ - // { - // "wrong type", - // wrong, - // nil, - // nil, - // nothing, - // ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: reflect.ValueOf(wrong)}, - // }, - // {"[]byte", []byte{0x01, 0x02, 0x03}, nil, nil, writeBinary, nil}, - // {"[]byte/nil", []byte(nil), nil, nil, writeNull, nil}, - // }, - //}, - //{ - // "EmptyInterfaceEncodeValue", - // &emptyInterfaceCodec{}, - // []subtest{ - // { - // "wrong type", - // wrong, - // nil, - // nil, - // nothing, - // ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.ValueOf(wrong)}, - // }, - // }, - //}, + { + "ByteSliceEncodeValue", + ValueEncoderFunc(byteSliceEncodeValue(false)), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + nothing, + ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: reflect.ValueOf(wrong)}, + }, + {"[]byte", []byte{0x01, 0x02, 0x03}, nil, nil, writeBinary, nil}, + {"[]byte/nil", []byte(nil), nil, nil, writeNull, nil}, + }, + }, + { + "EmptyInterfaceEncodeValue", + ValueEncoderFunc(emptyInterfaceValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + nothing, + ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.ValueOf(wrong)}, + }, + }, + }, { "ValueMarshalerEncodeValue", ValueEncoderFunc(valueMarshalerEncodeValue), @@ -1053,9 +1050,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CoreArrayEncodeValue", - ValueEncoderFunc(func(ec EncodeContext, vw ValueWriter, v reflect.Value) error { - return coreArrayEncodeValueRF(ec, vw, v.Interface()) - }), + ValueEncoderFunc(coreArrayEncodeValue), []subtest{ { "wrong type", From fda086be575697196a499feeb72121dcfffd6ee9 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 10 Apr 2025 17:02:21 -0600 Subject: [PATCH 07/13] GODRIVER-3455 Extend encoder tests to include RF case --- bson/default_value_decoders_test.go | 54 +++++----- bson/default_value_encoders.go | 154 +++++++--------------------- bson/default_value_encoders_test.go | 84 +++++++++++++-- 3 files changed, 140 insertions(+), 152 deletions(-) diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 269c287fce..bb44a4f894 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -3414,11 +3414,10 @@ func TestDefaultValueDecoders(t *testing.T) { // the top-level to decode to registered type when unmarshalling to interface{} topLevelReg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), - + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultEncoders(topLevelReg) @@ -3426,11 +3425,10 @@ func TestDefaultValueDecoders(t *testing.T) { topLevelReg.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) embeddedReg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), - + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultEncoders(embeddedReg) @@ -3474,11 +3472,10 @@ func TestDefaultValueDecoders(t *testing.T) { // type information is not available. reg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), - + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultEncoders(reg) @@ -3570,11 +3567,10 @@ func TestDefaultValueDecoders(t *testing.T) { // Use a registry that has all default decoders with the custom interface{} decoder that always errors. nestedRegistry := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), - + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultDecoders(nestedRegistry) @@ -3729,11 +3725,10 @@ func TestDefaultValueDecoders(t *testing.T) { ) reg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), - + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultDecoders(reg) @@ -3805,11 +3800,10 @@ func buildDocument(elems []byte) []byte { func buildDefaultRegistry() *Registry { reg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), - + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), } registerDefaultEncoders(reg) diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 378f9ae700..6ac772a486 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -63,7 +63,6 @@ func registerDefaultEncoders(reg *Registry) { // Register the reflect-free default type encoders. reg.registerReflectFreeTypeEncoder(tByteSlice, byteSliceEncodeValueRF(false)) reg.registerReflectFreeTypeEncoder(tTime, reflectFreeValueEncoderFunc(timeEncodeValueRF)) - reg.registerReflectFreeTypeEncoder(tEmpty, reflectFreeValueEncoderFunc(emptyInterfaceValueRF)) reg.registerReflectFreeTypeEncoder(tCoreArray, reflectFreeValueEncoderFunc(coreArrayEncodeValueRF)) reg.registerReflectFreeTypeEncoder(tNull, reflectFreeValueEncoderFunc(nullEncodeValueRF)) reg.registerReflectFreeTypeEncoder(tOID, reflectFreeValueEncoderFunc(objectIDEncodeValueRF)) @@ -76,13 +75,13 @@ func registerDefaultEncoders(reg *Registry) { reg.registerReflectFreeTypeEncoder(tVector, reflectFreeValueEncoderFunc(vectorEncodeValueRF)) reg.registerReflectFreeTypeEncoder(tUndefined, reflectFreeValueEncoderFunc(undefinedEncodeValueRF)) reg.registerReflectFreeTypeEncoder(tDateTime, reflectFreeValueEncoderFunc(dateTimeEncodeValueRF)) - reg.registerReflectFreeTypeEncoder(tRegex, reflectFreeValueEncoderFunc(regexEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tDBPointer, reflectFreeValueEncoderFunc(dbPointerEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tTimestamp, reflectFreeValueEncoderFunc(timestampEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tMinKey, reflectFreeValueEncoderFunc(minKeyEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tMaxKey, reflectFreeValueEncoderFunc(maxKeyEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tCoreDocument, reflectFreeValueEncoderFunc(coreDocumentEncodeValueX)) - reg.registerReflectFreeTypeEncoder(tCodeWithScope, reflectFreeValueEncoderFunc(codeWithScopeEncodeValueX)) + reg.registerReflectFreeTypeEncoder(tRegex, reflectFreeValueEncoderFunc(regexEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tDBPointer, reflectFreeValueEncoderFunc(dbPointerEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tTimestamp, reflectFreeValueEncoderFunc(timestampEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tMinKey, reflectFreeValueEncoderFunc(minKeyEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tMaxKey, reflectFreeValueEncoderFunc(maxKeyEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tCoreDocument, reflectFreeValueEncoderFunc(coreDocumentEncodeValueRF)) + reg.registerReflectFreeTypeEncoder(tCodeWithScope, reflectFreeValueEncoderFunc(codeWithScopeEncodeValueRF)) // Register the reflect-based default encoders. These are required since // removing them would break Registry.LookupEncoder. However, these will @@ -90,7 +89,7 @@ func registerDefaultEncoders(reg *Registry) { // reg.RegisterTypeEncoder(tByteSlice, byteSliceEncodeValue(false)) reg.RegisterTypeEncoder(tTime, defaultValueEncoderFunc(timeEncodeValue)) - reg.RegisterTypeEncoder(tEmpty, defaultValueEncoderFunc(emptyInterfaceValue)) // TODO: extend this to reflection free + reg.RegisterTypeEncoder(tEmpty, ValueEncoderFunc(emptyInterfaceValue)) reg.RegisterTypeEncoder(tCoreArray, defaultValueEncoderFunc(coreArrayEncodeValue)) reg.RegisterTypeEncoder(tOID, defaultValueEncoderFunc(objectIDEncodeValue)) reg.RegisterTypeEncoder(tDecimal, defaultValueEncoderFunc(decimal128EncodeValue)) @@ -508,18 +507,7 @@ func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { return nullEncodeValueRF(EncodeContext{}, vw, val.Interface()) } -// regexEncodeValue is the ValueEncoderFunc for Regex. -func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tRegex { - return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val} - } - - regex := val.Interface().(Regex) - - return vw.WriteRegex(regex.Pattern, regex.Options) -} - -func regexEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func regexEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { regex, ok := val.(Regex) if !ok { return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: reflect.ValueOf(val)} @@ -528,18 +516,12 @@ func regexEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteRegex(regex.Pattern, regex.Options) } -// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. -func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tDBPointer { - return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val} - } - - dbp := val.Interface().(DBPointer) - - return vw.WriteDBPointer(dbp.DB, dbp.Pointer) +// regexEncodeValue is the ValueEncoderFunc for Regex. +func regexEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return regexEncodeValueRF(ec, vw, val.Interface()) } -func dbPointerEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func dbPointerEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { dbp, ok := val.(DBPointer) if !ok { return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: reflect.ValueOf(val)} @@ -548,18 +530,12 @@ func dbPointerEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteDBPointer(dbp.DB, dbp.Pointer) } -// timestampEncodeValue is the ValueEncoderFunc for Timestamp. -func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tTimestamp { - return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val} - } - - ts := val.Interface().(Timestamp) - - return vw.WriteTimestamp(ts.T, ts.I) +// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. +func dbPointerEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return dbPointerEncodeValueRF(ec, vw, val.Interface()) } -func timestampEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func timestampEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { ts, ok := val.(Timestamp) if !ok { return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: reflect.ValueOf(val)} @@ -568,16 +544,12 @@ func timestampEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteTimestamp(ts.T, ts.I) } -// minKeyEncodeValue is the ValueEncoderFunc for MinKey. -func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tMinKey { - return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val} - } - - return vw.WriteMinKey() +// timestampEncodeValue is the ValueEncoderFunc for Timestamp. +func timestampEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return timestampEncodeValueRF(ec, vw, val.Interface()) } -func minKeyEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func minKeyEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { if _, ok := val.(MinKey); !ok { return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: reflect.ValueOf(val)} } @@ -585,16 +557,12 @@ func minKeyEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteMinKey() } -// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. -func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tMaxKey { - return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val} - } - - return vw.WriteMaxKey() +// minKeyEncodeValue is the ValueEncoderFunc for MinKey. +func minKeyEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return minKeyEncodeValueRF(ec, vw, val.Interface()) } -func maxKeyEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func maxKeyEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { if _, ok := val.(MaxKey); !ok { return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: reflect.ValueOf(val)} } @@ -602,18 +570,12 @@ func maxKeyEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteMaxKey() } -// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. -func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tCoreDocument { - return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} - } - - cdoc := val.Interface().(bsoncore.Document) - - return copyDocumentFromBytes(vw, cdoc) +// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. +func maxKeyEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return maxKeyEncodeValueRF(ec, vw, val.Interface()) } -func coreDocumentEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { +func coreDocumentEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { cdoc, ok := val.(bsoncore.Document) if !ok { return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: reflect.ValueOf(val)} @@ -622,45 +584,12 @@ func coreDocumentEncodeValueX(_ EncodeContext, vw ValueWriter, val any) error { return copyDocumentFromBytes(vw, cdoc) } -// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. -func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tCodeWithScope { - return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} - } - - cws := val.Interface().(CodeWithScope) - - dw, err := vw.WriteCodeWithScope(string(cws.Code)) - if err != nil { - return err - } - - sw := sliceWriterPool.Get().(*sliceWriter) - defer sliceWriterPool.Put(sw) - *sw = (*sw)[:0] - - scopeVW := bvwPool.Get().(*valueWriter) - scopeVW.reset(scopeVW.buf[:0]) - scopeVW.w = sw - defer bvwPool.Put(scopeVW) - encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope)) - if err != nil { - return err - } - - err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope)) - if err != nil { - return err - } - - err = copyBytesToDocumentWriter(dw, *sw) - if err != nil { - return err - } - return dw.WriteDocumentEnd() +// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. +func coreDocumentEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return coreDocumentEncodeValueRF(ec, vw, val.Interface()) } -func codeWithScopeEncodeValueX(ec EncodeContext, vw ValueWriter, val any) error { +func codeWithScopeEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { cws, ok := val.(CodeWithScope) if !ok { return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: reflect.ValueOf(val)} @@ -696,6 +625,11 @@ func codeWithScopeEncodeValueX(ec EncodeContext, vw ValueWriter, val any) error return dw.WriteDocumentEnd() } +// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. +func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { + return codeWithScopeEncodeValueRF(ec, vw, val.Interface()) +} + // isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type func isImplementationNil(val reflect.Value, inter reflect.Type) bool { vt := val.Type() @@ -757,18 +691,6 @@ func coreArrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) e return coreArrayEncodeValueRF(ec, vw, val.Interface()) } -func emptyInterfaceValueRF(ec EncodeContext, vw ValueWriter, val any) error { - if val == nil { - return vw.WriteNull() - } - encoder, err := ec.LookupEncoder(reflect.TypeOf(val)) - if err != nil { - return err - } - - return encoder.EncodeValue(ec, vw, reflect.ValueOf(val)) -} - func emptyInterfaceValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tEmpty { return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 1128cf4776..a8cd829163 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -73,11 +73,13 @@ func TestDefaultValueEncoders(t *testing.T) { testCases := []struct { name string ve ValueEncoder + rfve reflectFreeValueEncoder subtests []subtest }{ { "BooleanEncodeValue", ValueEncoderFunc(booleanEncodeValue), + nil, []subtest{ { "wrong type", @@ -94,6 +96,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "IntEncodeValue", ValueEncoderFunc(intEncodeValue), + nil, []subtest{ { "wrong type", @@ -134,6 +137,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "UintEncodeValue", &uintCodec{}, + nil, []subtest{ { "wrong type", @@ -175,6 +179,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "FloatEncodeValue", ValueEncoderFunc(floatEncodeValue), + nil, []subtest{ { "wrong type", @@ -194,9 +199,34 @@ func TestDefaultValueEncoders(t *testing.T) { {"float64/reflection path", myfloat64(3.14159), nil, nil, writeDouble, nil}, }, }, + { + "reflection free FloatEncodeValue", + nil, + reflectFreeValueEncoderFunc(floatEncodeValueRF), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + nothing, + ValueEncoderError{ + Name: "FloatEncodeValue", + Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Received: reflect.ValueOf(wrong), + }, + }, + // the reflection free encoder function should only be used for + // encoding "types", not "kinds". So the reflection path tests are not + // valid. + {"float32/fast path", float32(3.14159), nil, nil, writeDouble, nil}, + {"float64/fast path", float64(3.14159), nil, nil, writeDouble, nil}, + }, + }, { "TimeEncodeValue", ValueEncoderFunc(timeEncodeValue), + reflectFreeValueEncoderFunc(timeEncodeValueRF), []subtest{ { "wrong type", @@ -212,6 +242,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "MapEncodeValue", &mapCodec{}, + nil, []subtest{ { "wrong kind", @@ -292,6 +323,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "ArrayEncodeValue", ValueEncoderFunc(arrayEncodeValue), + nil, []subtest{ { "wrong kind", @@ -370,6 +402,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "SliceEncodeValue", &sliceCodec{}, + nil, []subtest{ { "wrong kind", @@ -456,6 +489,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "ObjectIDEncodeValue", ValueEncoderFunc(objectIDEncodeValue), + reflectFreeValueEncoderFunc(objectIDEncodeValueRF), []subtest{ { "wrong type", @@ -475,6 +509,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Decimal128EncodeValue", ValueEncoderFunc(decimal128EncodeValue), + reflectFreeValueEncoderFunc(decimal128EncodeValueRF), []subtest{ { "wrong type", @@ -490,6 +525,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "JSONNumberEncodeValue", ValueEncoderFunc(jsonNumberEncodeValue), + reflectFreeValueEncoderFunc(jsonNumberEncodeValueRF), []subtest{ { "wrong type", @@ -519,6 +555,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "URLEncodeValue", ValueEncoderFunc(urlEncodeValue), + reflectFreeValueEncoderFunc(urlEncodeValueRF), []subtest{ { "wrong type", @@ -534,6 +571,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "ByteSliceEncodeValue", ValueEncoderFunc(byteSliceEncodeValue(false)), + reflectFreeValueEncoderFunc(byteSliceEncodeValueRF(false)), []subtest{ { "wrong type", @@ -550,6 +588,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "EmptyInterfaceEncodeValue", ValueEncoderFunc(emptyInterfaceValue), + nil, []subtest{ { "wrong type", @@ -564,6 +603,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "ValueMarshalerEncodeValue", ValueEncoderFunc(valueMarshalerEncodeValue), + nil, []subtest{ { "wrong type", @@ -642,6 +682,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "MarshalerEncodeValue", ValueEncoderFunc(marshalerEncodeValue), + nil, []subtest{ { "wrong type", @@ -704,6 +745,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "PointerCodec.EncodeValue", &pointerCodec{}, + nil, []subtest{ { "nil", @@ -742,6 +784,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "pointer implementation addressable interface", &pointerCodec{}, + nil, []subtest{ { "ValueMarshaler", @@ -764,6 +807,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "JavaScriptEncodeValue", ValueEncoderFunc(javaScriptEncodeValue), + reflectFreeValueEncoderFunc(javaScriptEncodeValueRF), []subtest{ { "wrong type", @@ -779,6 +823,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "SymbolEncodeValue", ValueEncoderFunc(symbolEncodeValue), + reflectFreeValueEncoderFunc(symbolEncodeValueRF), []subtest{ { "wrong type", @@ -794,6 +839,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "BinaryEncodeValue", ValueEncoderFunc(binaryEncodeValue), + reflectFreeValueEncoderFunc(binaryEncodeValueRF), []subtest{ { "wrong type", @@ -809,6 +855,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "UndefinedEncodeValue", ValueEncoderFunc(undefinedEncodeValue), + reflectFreeValueEncoderFunc(undefinedEncodeValueRF), []subtest{ { "wrong type", @@ -824,6 +871,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "DateTimeEncodeValue", ValueEncoderFunc(dateTimeEncodeValue), + reflectFreeValueEncoderFunc(dateTimeEncodeValueRF), []subtest{ { "wrong type", @@ -839,6 +887,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "NullEncodeValue", ValueEncoderFunc(nullEncodeValue), + reflectFreeValueEncoderFunc(nullEncodeValueRF), []subtest{ { "wrong type", @@ -854,6 +903,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "RegexEncodeValue", ValueEncoderFunc(regexEncodeValue), + reflectFreeValueEncoderFunc(regexEncodeValueRF), []subtest{ { "wrong type", @@ -869,6 +919,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "DBPointerEncodeValue", ValueEncoderFunc(dbPointerEncodeValue), + reflectFreeValueEncoderFunc(dbPointerEncodeValueRF), []subtest{ { "wrong type", @@ -891,6 +942,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "TimestampEncodeValue", ValueEncoderFunc(timestampEncodeValue), + reflectFreeValueEncoderFunc(timestampEncodeValueRF), []subtest{ { "wrong type", @@ -906,6 +958,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "MinKeyEncodeValue", ValueEncoderFunc(minKeyEncodeValue), + reflectFreeValueEncoderFunc(minKeyEncodeValueRF), []subtest{ { "wrong type", @@ -921,6 +974,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "MaxKeyEncodeValue", ValueEncoderFunc(maxKeyEncodeValue), + reflectFreeValueEncoderFunc(maxKeyEncodeValueRF), []subtest{ { "wrong type", @@ -936,6 +990,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "CoreDocumentEncodeValue", ValueEncoderFunc(coreDocumentEncodeValue), + reflectFreeValueEncoderFunc(coreDocumentEncodeValueRF), []subtest{ { "wrong type", @@ -994,6 +1049,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "StructEncodeValue", newStructCodec(&mapCodec{}), + nil, []subtest{ { "interface value", @@ -1016,6 +1072,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "CodeWithScopeEncodeValue", ValueEncoderFunc(codeWithScopeEncodeValue), + reflectFreeValueEncoderFunc(codeWithScopeEncodeValueRF), []subtest{ { "wrong type", @@ -1051,6 +1108,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "CoreArrayEncodeValue", ValueEncoderFunc(coreArrayEncodeValue), + reflectFreeValueEncoderFunc(coreArrayEncodeValueRF), []subtest{ { "wrong type", @@ -1110,13 +1168,27 @@ func TestDefaultValueEncoders(t *testing.T) { llvrw = subtest.llvrw } llvrw.T = t - err := tc.ve.EncodeValue(ec, llvrw, reflect.ValueOf(subtest.val)) - if !assert.CompareErrors(err, subtest.err) { - t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) + + if tc.ve != nil { + err := tc.ve.EncodeValue(ec, llvrw, reflect.ValueOf(subtest.val)) + if !assert.CompareErrors(err, subtest.err) { + t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) + } + invoked := llvrw.invoked + if !cmp.Equal(invoked, subtest.invoke) { + t.Errorf("Incorrect method invoked. got %v; want %v", invoked, subtest.invoke) + } } - invoked := llvrw.invoked - if !cmp.Equal(invoked, subtest.invoke) { - t.Errorf("Incorrect method invoked. got %v; want %v", invoked, subtest.invoke) + + if tc.rfve != nil { + err := tc.rfve.EncodeValue(ec, llvrw, subtest.val) + if !assert.CompareErrors(err, subtest.err) { + t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) + } + invoked := llvrw.invoked + if !cmp.Equal(invoked, subtest.invoke) { + t.Errorf("Incorrect method invoked. got %v; want %v", invoked, subtest.invoke) + } } }) } From da49bdb6dbe047fbe96406f10b8a8a6c8f777c72 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 11 Apr 2025 11:02:59 -0600 Subject: [PATCH 08/13] GODRIVER-3455 Fix linting issues --- bson/byte_slice_codec.go | 6 +----- bson/default_value_encoders.go | 6 +++--- bson/default_value_encoders_test.go | 2 +- bson/registry.go | 4 ---- 4 files changed, 5 insertions(+), 13 deletions(-) diff --git a/bson/byte_slice_codec.go b/bson/byte_slice_codec.go index ce0ad976ec..d6d27fcc86 100644 --- a/bson/byte_slice_codec.go +++ b/bson/byte_slice_codec.go @@ -12,11 +12,7 @@ import ( ) // byteSliceCodec is the Codec used for []byte values. -type byteSliceCodec struct { - // encodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values - // instead of BSON null. - encodeNilAsEmpty bool -} +type byteSliceCodec struct{} // Assert that byteSliceCodec satisfies the typeDecoder interface, which allows it to be // used by collection type decoders (e.g. map, slice, etc) to set individual values in a diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 6ac772a486..ccd37e7413 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -190,7 +190,7 @@ func floatEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { // floatEncodeValue is the ValueEncoderFunc for float types. this function is // used to decode "types" and "kinds" and therefore cannot be a wrapper for // reflection-free decoding in the default "type" case. -func floatEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Float32, reflect.Float64: return vw.WriteDouble(val.Float()) @@ -664,7 +664,7 @@ func byteSliceEncodeValue(encodeNilAsEmpty bool) defaultValueEncoderFunc { }) } -func timeEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { +func timeEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { tt, ok := val.(time.Time) if !ok { return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: reflect.ValueOf(val)} @@ -678,7 +678,7 @@ func timeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error return timeEncodeValueRF(ec, vw, val.Interface()) } -func coreArrayEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { +func coreArrayEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { arr, ok := val.(bsoncore.Array) if !ok { return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: reflect.ValueOf(val)} diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index a8cd829163..70aae59f3f 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -571,7 +571,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "ByteSliceEncodeValue", ValueEncoderFunc(byteSliceEncodeValue(false)), - reflectFreeValueEncoderFunc(byteSliceEncodeValueRF(false)), + byteSliceEncodeValueRF(false), []subtest{ { "wrong type", diff --git a/bson/registry.go b/bson/registry.go index 3bca9dcb8d..88c94ff828 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -126,10 +126,6 @@ func (r *Registry) registerReflectFreeTypeEncoder(valueType reflect.Type, enc re r.reflectFreeTypeEncoders.Store(valueType, enc) } -func (r *Registry) storeReflectFreeTypeEncoder(rt reflect.Type, enc reflectFreeValueEncoder) reflectFreeValueEncoder { - return r.reflectFreeTypeEncoders.LoadOrStore(rt, enc) -} - // RegisterTypeDecoder registers the provided ValueDecoder for the provided type. // // The type will be used as provided, so a decoder can be registered for a type and a different From a3e8c9c5cc19e720881a9395f890a71a6f4c1084 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 11 Apr 2025 11:41:52 -0600 Subject: [PATCH 09/13] GODRIVER-3455 Remove unused EncodeValue methods --- bson/array_codec.go | 10 ---------- bson/default_value_encoders_test.go | 8 ++++---- bson/empty_interface_codec.go | 22 ---------------------- bson/time_codec.go | 11 ----------- 4 files changed, 4 insertions(+), 47 deletions(-) diff --git a/bson/array_codec.go b/bson/array_codec.go index aa65803959..d235b69805 100644 --- a/bson/array_codec.go +++ b/bson/array_codec.go @@ -15,16 +15,6 @@ import ( // arrayCodec is the Codec used for bsoncore.Array values. type arrayCodec struct{} -// EncodeValue is the ValueEncoder for bsoncore.Array values. -func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val any) error { - arr, ok := val.(bsoncore.Array) - if !ok { - return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: reflect.ValueOf(val)} - } - - return copyArrayFromBytes(vw, arr) -} - // DecodeValue is the ValueDecoder for bsoncore.Array values. func (ac *arrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCoreArray { diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 70aae59f3f..678450049f 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -1791,7 +1791,7 @@ func TestDefaultValueEncoders(t *testing.T) { t.Run("EmptyInterfaceEncodeValue/nil", func(t *testing.T) { val := reflect.New(tEmpty).Elem() llvrw := new(valueReaderWriter) - err := (&emptyInterfaceCodec{}).EncodeValue(EncodeContext{Registry: newTestRegistry()}, llvrw, val) + err := emptyInterfaceValue(EncodeContext{Registry: newTestRegistry()}, llvrw, val) noerr(t, err) if llvrw.invoked != writeNull { t.Errorf("Incorrect method called. got %v; want %v", llvrw.invoked, writeNull) @@ -1802,10 +1802,10 @@ func TestDefaultValueEncoders(t *testing.T) { val := reflect.New(tEmpty).Elem() val.Set(reflect.ValueOf(int64(1234567890))) llvrw := new(valueReaderWriter) - got := (&emptyInterfaceCodec{}).EncodeValue(EncodeContext{Registry: newTestRegistry()}, llvrw, val) + err := emptyInterfaceValue(EncodeContext{Registry: newTestRegistry()}, llvrw, val) want := errNoEncoder{Type: tInt64} - if !assert.CompareErrors(got, want) { - t.Errorf("Did not receive expected error. got %v; want %v", got, want) + if !assert.CompareErrors(err, want) { + t.Errorf("Did not receive expected error. got %v; want %v", err, want) } }) } diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index 80d44d8c66..b669dfd8d9 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -17,28 +17,6 @@ type emptyInterfaceCodec struct { decodeBinaryAsSlice bool } -// Assert that emptyInterfaceCodec satisfies the typeDecoder interface, which allows it -// to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a -// collection. -var _ typeDecoder = &emptyInterfaceCodec{} - -// EncodeValue is the ValueEncoderFunc for interface{}. -func (eic *emptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tEmpty { - return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} - } - - if val.IsNil() { - return vw.WriteNull() - } - encoder, err := ec.LookupEncoder(val.Elem().Type()) - if err != nil { - return err - } - - return encoder.EncodeValue(ec, vw, val.Elem()) -} - func (eic *emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) { isDocument := valueType == Type(0) || valueType == TypeEmbeddedDocument if isDocument { diff --git a/bson/time_codec.go b/bson/time_codec.go index 85d37496fc..32be418a6b 100644 --- a/bson/time_codec.go +++ b/bson/time_codec.go @@ -97,14 +97,3 @@ func (tc *timeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.V val.Set(elem) return nil } - -// EncodeValue is the ValueEncoderFunc for time.TIme. -func (tc *timeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val any) error { - timeVal, ok := val.(time.Time) - if !ok { - return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: reflect.ValueOf(val)} - } - - dt := NewDateTimeFromTime(timeVal) - return vw.WriteDateTime(int64(dt)) -} From 7a17cc6ec5e66b9894d278cd0c3809f05b23fc9b Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 11 Apr 2025 11:48:55 -0600 Subject: [PATCH 10/13] GODRIVER-3455 Add comments to reflectFreeValueEncoder logic --- bson/bsoncodec.go | 9 ++++++++- bson/default_value_encoders.go | 1 - 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 1a8ca03104..18eb30c142 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -140,7 +140,10 @@ func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val ref } // defaultValueEncoderFunc is an adapter function that allows a function with -// the correct signature to be used as a ValueEncoder. +// the correct signature to be used as a ValueEncoder. This differentiates +// between user-defined ValueEncoders and driver-defined ValueEncoders with the +// goal of forgoing drvier-defined behavior in favor of a reflection-free option +// if one exists. type defaultValueEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error // EncodeValue implements the ValueEncoder interface. @@ -148,12 +151,16 @@ func (fn defaultValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, return fn(ec, vw, val) } +// reflectFreeValueEncoder is a reflect-free version of ValueEncoder. type reflectFreeValueEncoder interface { EncodeValue(ec EncodeContext, vw ValueWriter, val any) error } +// reflectFreeValueEncoderFunc is an adapter function that allows a function +// with the correct signature to be used as a reflectFreeValueEncoder. type reflectFreeValueEncoderFunc func(ec EncodeContext, vw ValueWriter, val any) error +// EncodeValue implements the reflectFreeValueEncoder interface. func (fn reflectFreeValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val any) error { return fn(ec, vw, val) } diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index ccd37e7413..e7f5ed5bc2 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -86,7 +86,6 @@ func registerDefaultEncoders(reg *Registry) { // Register the reflect-based default encoders. These are required since // removing them would break Registry.LookupEncoder. However, these will // never be used internally. - // reg.RegisterTypeEncoder(tByteSlice, byteSliceEncodeValue(false)) reg.RegisterTypeEncoder(tTime, defaultValueEncoderFunc(timeEncodeValue)) reg.RegisterTypeEncoder(tEmpty, ValueEncoderFunc(emptyInterfaceValue)) From cf3a91e507a826b3dd9f23c23eeeb0fdab1bdbf8 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 11 Apr 2025 11:57:55 -0600 Subject: [PATCH 11/13] GODRIVER-3455 Remove reflect-based type encoders --- bson/codec_cache.go | 12 +-- bson/default_value_decoders_test.go | 12 +-- bson/default_value_encoders.go | 129 +--------------------------- bson/default_value_encoders_test.go | 40 ++++----- bson/registry.go | 4 +- bson/registry_test.go | 2 +- 6 files changed, 36 insertions(+), 163 deletions(-) diff --git a/bson/codec_cache.go b/bson/codec_cache.go index 4fffe0eff1..7bf8f7c419 100644 --- a/bson/codec_cache.go +++ b/bson/codec_cache.go @@ -58,30 +58,30 @@ func (c *typeEncoderCache) Clone() *typeEncoderCache { return cc } -type typeReflectFreeEncoderCache struct { +type reflectFreeTypeEncoderCache struct { cache sync.Map // map[reflect.Type]typeReflectFreeEncoderCache } -func (c *typeReflectFreeEncoderCache) Store(rt reflect.Type, enc reflectFreeValueEncoder) { +func (c *reflectFreeTypeEncoderCache) Store(rt reflect.Type, enc reflectFreeValueEncoder) { c.cache.Store(rt, enc) } -func (c *typeReflectFreeEncoderCache) Load(rt reflect.Type) (reflectFreeValueEncoder, bool) { +func (c *reflectFreeTypeEncoderCache) Load(rt reflect.Type) (reflectFreeValueEncoder, bool) { if v, _ := c.cache.Load(rt); v != nil { return v.(reflectFreeValueEncoder), true } return nil, false } -func (c *typeReflectFreeEncoderCache) LoadOrStore(rt reflect.Type, enc reflectFreeValueEncoder) reflectFreeValueEncoder { +func (c *reflectFreeTypeEncoderCache) LoadOrStore(rt reflect.Type, enc reflectFreeValueEncoder) reflectFreeValueEncoder { if v, loaded := c.cache.LoadOrStore(rt, enc); loaded { enc = v.(reflectFreeValueEncoder) } return enc } -func (c *typeReflectFreeEncoderCache) Clone() *typeReflectFreeEncoderCache { - cc := new(typeReflectFreeEncoderCache) +func (c *reflectFreeTypeEncoderCache) Clone() *reflectFreeTypeEncoderCache { + cc := new(reflectFreeTypeEncoderCache) c.cache.Range(func(k, v interface{}) bool { if k != nil && v != nil { cc.cache.Store(k, v) diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index bb44a4f894..7ea2e55efd 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -3418,7 +3418,7 @@ func TestDefaultValueDecoders(t *testing.T) { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), - reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultEncoders(topLevelReg) registerDefaultDecoders(topLevelReg) @@ -3429,7 +3429,7 @@ func TestDefaultValueDecoders(t *testing.T) { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), - reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultEncoders(embeddedReg) registerDefaultDecoders(embeddedReg) @@ -3476,7 +3476,7 @@ func TestDefaultValueDecoders(t *testing.T) { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), - reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) @@ -3571,7 +3571,7 @@ func TestDefaultValueDecoders(t *testing.T) { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), - reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultDecoders(nestedRegistry) nestedRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) @@ -3729,7 +3729,7 @@ func TestDefaultValueDecoders(t *testing.T) { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), - reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultDecoders(reg) reg.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) @@ -3804,7 +3804,7 @@ func buildDefaultRegistry() *Registry { typeDecoders: new(typeDecoderCache), kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), - reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index e7f5ed5bc2..8319298a69 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -83,31 +83,8 @@ func registerDefaultEncoders(reg *Registry) { reg.registerReflectFreeTypeEncoder(tCoreDocument, reflectFreeValueEncoderFunc(coreDocumentEncodeValueRF)) reg.registerReflectFreeTypeEncoder(tCodeWithScope, reflectFreeValueEncoderFunc(codeWithScopeEncodeValueRF)) - // Register the reflect-based default encoders. These are required since - // removing them would break Registry.LookupEncoder. However, these will - // never be used internally. - reg.RegisterTypeEncoder(tByteSlice, byteSliceEncodeValue(false)) - reg.RegisterTypeEncoder(tTime, defaultValueEncoderFunc(timeEncodeValue)) + // Register the reflect-based default encoders. reg.RegisterTypeEncoder(tEmpty, ValueEncoderFunc(emptyInterfaceValue)) - reg.RegisterTypeEncoder(tCoreArray, defaultValueEncoderFunc(coreArrayEncodeValue)) - reg.RegisterTypeEncoder(tOID, defaultValueEncoderFunc(objectIDEncodeValue)) - reg.RegisterTypeEncoder(tDecimal, defaultValueEncoderFunc(decimal128EncodeValue)) - reg.RegisterTypeEncoder(tJSONNumber, defaultValueEncoderFunc(jsonNumberEncodeValue)) - reg.RegisterTypeEncoder(tURL, defaultValueEncoderFunc(urlEncodeValue)) - reg.RegisterTypeEncoder(tJavaScript, defaultValueEncoderFunc(javaScriptEncodeValue)) - reg.RegisterTypeEncoder(tSymbol, defaultValueEncoderFunc(symbolEncodeValue)) - reg.RegisterTypeEncoder(tBinary, defaultValueEncoderFunc(binaryEncodeValue)) - reg.RegisterTypeEncoder(tVector, defaultValueEncoderFunc(vectorEncodeValue)) - reg.RegisterTypeEncoder(tUndefined, defaultValueEncoderFunc(undefinedEncodeValue)) - reg.RegisterTypeEncoder(tDateTime, defaultValueEncoderFunc(dateTimeEncodeValue)) - reg.RegisterTypeEncoder(tNull, defaultValueEncoderFunc(nullEncodeValue)) - reg.RegisterTypeEncoder(tRegex, defaultValueEncoderFunc(regexEncodeValue)) - reg.RegisterTypeEncoder(tDBPointer, defaultValueEncoderFunc(dbPointerEncodeValue)) - reg.RegisterTypeEncoder(tTimestamp, defaultValueEncoderFunc(timestampEncodeValue)) - reg.RegisterTypeEncoder(tMinKey, defaultValueEncoderFunc(minKeyEncodeValue)) - reg.RegisterTypeEncoder(tMaxKey, defaultValueEncoderFunc(maxKeyEncodeValue)) - reg.RegisterTypeEncoder(tCoreDocument, defaultValueEncoderFunc(coreDocumentEncodeValue)) - reg.RegisterTypeEncoder(tCodeWithScope, defaultValueEncoderFunc(codeWithScopeEncodeValue)) // Register the kind-based default encoders. These must continue using // reflection since they account for custom types that cannot be anticipated. @@ -212,11 +189,6 @@ func objectIDEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteObjectID(objID) } -// objectIDEncodeValue is the ValueEncoderFunc for ObjectID. -func objectIDEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return objectIDEncodeValueRF(ec, vw, val.Interface()) -} - func decimal128EncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { d128, ok := val.(Decimal128) if !ok { @@ -226,11 +198,6 @@ func decimal128EncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteDecimal128(d128) } -// decimal128EncodeValue is the ValueEncoderFunc for Decimal128. -func decimal128EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return decimal128EncodeValueRF(ec, vw, val.Interface()) -} - func jsonNumberEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { jsnum, ok := val.(json.Number) if !ok { @@ -250,11 +217,6 @@ func jsonNumberEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { return floatEncodeValueRF(ec, vw, f64) } -// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. -func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return jsonNumberEncodeValueRF(ec, vw, val.Interface()) -} - func urlEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { u, ok := val.(url.URL) if !ok { @@ -264,11 +226,6 @@ func urlEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteString(u.String()) } -// urlEncodeValue is the ValueEncoderFunc for url.URL. -func urlEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return urlEncodeValueRF(ec, vw, val.Interface()) -} - // arrayEncodeValue is the ValueEncoderFunc for array types. func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { @@ -418,11 +375,6 @@ func javaScriptEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteJavascript(string(jsString)) } -// javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. -func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - return javaScriptEncodeValueRF(EncodeContext{}, vw, val.Interface()) -} - func symbolEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { symbol, ok := val.(Symbol) if !ok { @@ -432,11 +384,6 @@ func symbolEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteSymbol(string(symbol)) } -// symbolEncodeValue is the ValueEncoderFunc for the Symbol type. -func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - return symbolEncodeValueRF(EncodeContext{}, vw, val.Interface()) -} - func binaryEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { b, ok := val.(Binary) if !ok { @@ -446,11 +393,6 @@ func binaryEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } -// binaryEncodeValue is the ValueEncoderFunc for Binary. -func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - return binaryEncodeValueRF(EncodeContext{}, vw, val.Interface()) -} - func vectorEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { v, ok := val.(Vector) if !ok { @@ -461,11 +403,6 @@ func vectorEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } -// vectorEncodeValue is the ValueEncoderFunc for Vector. -func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - return vectorEncodeValueRF(EncodeContext{}, vw, val.Interface()) -} - func undefinedEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { if _, ok := val.(Undefined); !ok { return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: reflect.ValueOf(val)} @@ -474,11 +411,6 @@ func undefinedEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteUndefined() } -// undefinedEncodeValue is the ValueEncoderFunc for Undefined. -func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - return undefinedEncodeValueRF(EncodeContext{}, vw, val.Interface()) -} - func dateTimeEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { dateTime, ok := val.(DateTime) if !ok { @@ -488,11 +420,6 @@ func dateTimeEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteDateTime(int64(dateTime)) } -// dateTimeEncodeValue is the ValueEncoderFunc for DateTime. -func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - return dateTimeEncodeValueRF(EncodeContext{}, vw, val.Interface()) -} - func nullEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { if _, ok := val.(Null); !ok { return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: reflect.ValueOf(val)} @@ -501,11 +428,6 @@ func nullEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteNull() } -// nullEncodeValue is the ValueEncoderFunc for Null. -func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - return nullEncodeValueRF(EncodeContext{}, vw, val.Interface()) -} - func regexEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { regex, ok := val.(Regex) if !ok { @@ -515,11 +437,6 @@ func regexEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteRegex(regex.Pattern, regex.Options) } -// regexEncodeValue is the ValueEncoderFunc for Regex. -func regexEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return regexEncodeValueRF(ec, vw, val.Interface()) -} - func dbPointerEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { dbp, ok := val.(DBPointer) if !ok { @@ -529,11 +446,6 @@ func dbPointerEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteDBPointer(dbp.DB, dbp.Pointer) } -// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. -func dbPointerEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return dbPointerEncodeValueRF(ec, vw, val.Interface()) -} - func timestampEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { ts, ok := val.(Timestamp) if !ok { @@ -543,11 +455,6 @@ func timestampEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteTimestamp(ts.T, ts.I) } -// timestampEncodeValue is the ValueEncoderFunc for Timestamp. -func timestampEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return timestampEncodeValueRF(ec, vw, val.Interface()) -} - func minKeyEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { if _, ok := val.(MinKey); !ok { return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: reflect.ValueOf(val)} @@ -556,11 +463,6 @@ func minKeyEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteMinKey() } -// minKeyEncodeValue is the ValueEncoderFunc for MinKey. -func minKeyEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return minKeyEncodeValueRF(ec, vw, val.Interface()) -} - func maxKeyEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { if _, ok := val.(MaxKey); !ok { return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: reflect.ValueOf(val)} @@ -569,11 +471,6 @@ func maxKeyEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteMaxKey() } -// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. -func maxKeyEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return maxKeyEncodeValueRF(ec, vw, val.Interface()) -} - func coreDocumentEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { cdoc, ok := val.(bsoncore.Document) if !ok { @@ -583,11 +480,6 @@ func coreDocumentEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return copyDocumentFromBytes(vw, cdoc) } -// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. -func coreDocumentEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return coreDocumentEncodeValueRF(ec, vw, val.Interface()) -} - func codeWithScopeEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error { cws, ok := val.(CodeWithScope) if !ok { @@ -624,11 +516,6 @@ func codeWithScopeEncodeValueRF(ec EncodeContext, vw ValueWriter, val any) error return dw.WriteDocumentEnd() } -// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. -func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return codeWithScopeEncodeValueRF(ec, vw, val.Interface()) -} - // isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type func isImplementationNil(val reflect.Value, inter reflect.Type) bool { vt := val.Type() @@ -657,12 +544,6 @@ func byteSliceEncodeValueRF(encodeNilAsEmpty bool) reflectFreeValueEncoderFunc { }) } -func byteSliceEncodeValue(encodeNilAsEmpty bool) defaultValueEncoderFunc { - return defaultValueEncoderFunc(func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return byteSliceEncodeValueRF(encodeNilAsEmpty)(ec, vw, val.Interface()) - }) -} - func timeEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { tt, ok := val.(time.Time) if !ok { @@ -673,10 +554,6 @@ func timeEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return vw.WriteDateTime(int64(dt)) } -func timeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return timeEncodeValueRF(ec, vw, val.Interface()) -} - func coreArrayEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { arr, ok := val.(bsoncore.Array) if !ok { @@ -686,10 +563,6 @@ func coreArrayEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { return copyArrayFromBytes(vw, arr) } -func coreArrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return coreArrayEncodeValueRF(ec, vw, val.Interface()) -} - func emptyInterfaceValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tEmpty { return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 678450049f..fa01fdbacb 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -225,7 +225,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "TimeEncodeValue", - ValueEncoderFunc(timeEncodeValue), + nil, reflectFreeValueEncoderFunc(timeEncodeValueRF), []subtest{ { @@ -488,7 +488,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ObjectIDEncodeValue", - ValueEncoderFunc(objectIDEncodeValue), + nil, reflectFreeValueEncoderFunc(objectIDEncodeValueRF), []subtest{ { @@ -508,7 +508,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "Decimal128EncodeValue", - ValueEncoderFunc(decimal128EncodeValue), + nil, reflectFreeValueEncoderFunc(decimal128EncodeValueRF), []subtest{ { @@ -524,7 +524,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "JSONNumberEncodeValue", - ValueEncoderFunc(jsonNumberEncodeValue), + nil, reflectFreeValueEncoderFunc(jsonNumberEncodeValueRF), []subtest{ { @@ -554,7 +554,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "URLEncodeValue", - ValueEncoderFunc(urlEncodeValue), + nil, reflectFreeValueEncoderFunc(urlEncodeValueRF), []subtest{ { @@ -570,7 +570,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ByteSliceEncodeValue", - ValueEncoderFunc(byteSliceEncodeValue(false)), + nil, byteSliceEncodeValueRF(false), []subtest{ { @@ -806,7 +806,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "JavaScriptEncodeValue", - ValueEncoderFunc(javaScriptEncodeValue), + nil, reflectFreeValueEncoderFunc(javaScriptEncodeValueRF), []subtest{ { @@ -822,7 +822,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "SymbolEncodeValue", - ValueEncoderFunc(symbolEncodeValue), + nil, reflectFreeValueEncoderFunc(symbolEncodeValueRF), []subtest{ { @@ -838,7 +838,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "BinaryEncodeValue", - ValueEncoderFunc(binaryEncodeValue), + nil, reflectFreeValueEncoderFunc(binaryEncodeValueRF), []subtest{ { @@ -854,7 +854,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "UndefinedEncodeValue", - ValueEncoderFunc(undefinedEncodeValue), + nil, reflectFreeValueEncoderFunc(undefinedEncodeValueRF), []subtest{ { @@ -870,7 +870,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "DateTimeEncodeValue", - ValueEncoderFunc(dateTimeEncodeValue), + nil, reflectFreeValueEncoderFunc(dateTimeEncodeValueRF), []subtest{ { @@ -886,7 +886,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "NullEncodeValue", - ValueEncoderFunc(nullEncodeValue), + nil, reflectFreeValueEncoderFunc(nullEncodeValueRF), []subtest{ { @@ -902,7 +902,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "RegexEncodeValue", - ValueEncoderFunc(regexEncodeValue), + nil, reflectFreeValueEncoderFunc(regexEncodeValueRF), []subtest{ { @@ -918,7 +918,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "DBPointerEncodeValue", - ValueEncoderFunc(dbPointerEncodeValue), + nil, reflectFreeValueEncoderFunc(dbPointerEncodeValueRF), []subtest{ { @@ -941,7 +941,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "TimestampEncodeValue", - ValueEncoderFunc(timestampEncodeValue), + nil, reflectFreeValueEncoderFunc(timestampEncodeValueRF), []subtest{ { @@ -957,7 +957,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MinKeyEncodeValue", - ValueEncoderFunc(minKeyEncodeValue), + nil, reflectFreeValueEncoderFunc(minKeyEncodeValueRF), []subtest{ { @@ -973,7 +973,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MaxKeyEncodeValue", - ValueEncoderFunc(maxKeyEncodeValue), + nil, reflectFreeValueEncoderFunc(maxKeyEncodeValueRF), []subtest{ { @@ -989,7 +989,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CoreDocumentEncodeValue", - ValueEncoderFunc(coreDocumentEncodeValue), + nil, reflectFreeValueEncoderFunc(coreDocumentEncodeValueRF), []subtest{ { @@ -1071,7 +1071,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CodeWithScopeEncodeValue", - ValueEncoderFunc(codeWithScopeEncodeValue), + nil, reflectFreeValueEncoderFunc(codeWithScopeEncodeValueRF), []subtest{ { @@ -1107,7 +1107,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CoreArrayEncodeValue", - ValueEncoderFunc(coreArrayEncodeValue), + nil, reflectFreeValueEncoderFunc(coreArrayEncodeValueRF), []subtest{ { diff --git a/bson/registry.go b/bson/registry.go index 88c94ff828..45b28c2d67 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -89,7 +89,7 @@ type Registry struct { kindDecoders *kindDecoderCache typeMap sync.Map // map[Type]reflect.Type - reflectFreeTypeEncoders *typeReflectFreeEncoderCache + reflectFreeTypeEncoders *reflectFreeTypeEncoderCache } // NewRegistry creates a new empty Registry. @@ -100,7 +100,7 @@ func NewRegistry() *Registry { kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), - reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } registerDefaultEncoders(reg) registerDefaultDecoders(reg) diff --git a/bson/registry_test.go b/bson/registry_test.go index f826ab5e78..fd603cd66e 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -23,7 +23,7 @@ func newTestRegistry() *Registry { kindEncoders: new(kindEncoderCache), kindDecoders: new(kindDecoderCache), - reflectFreeTypeEncoders: new(typeReflectFreeEncoderCache), + reflectFreeTypeEncoders: new(reflectFreeTypeEncoderCache), } } From b65e94500709a112b551be934293d18a45b8a027 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 11 Apr 2025 11:59:43 -0600 Subject: [PATCH 12/13] GODRIVER-3455 Remove defaultValueEncoderFunc --- bson/bsoncodec.go | 12 ------------ bson/registry.go | 4 +--- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 18eb30c142..276b02f80c 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -139,18 +139,6 @@ func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val ref return fn(ec, vw, val) } -// defaultValueEncoderFunc is an adapter function that allows a function with -// the correct signature to be used as a ValueEncoder. This differentiates -// between user-defined ValueEncoders and driver-defined ValueEncoders with the -// goal of forgoing drvier-defined behavior in favor of a reflection-free option -// if one exists. -type defaultValueEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error - -// EncodeValue implements the ValueEncoder interface. -func (fn defaultValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return fn(ec, vw, val) -} - // reflectFreeValueEncoder is a reflect-free version of ValueEncoder. type reflectFreeValueEncoder interface { EncodeValue(ec EncodeContext, vw ValueWriter, val any) error diff --git a/bson/registry.go b/bson/registry.go index 45b28c2d67..dbe4595e10 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -259,9 +259,7 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { return nil, errNoEncoder{Type: valueType} } - if _, ok := enc.(defaultValueEncoderFunc); !ok { - return enc, nil - } + return enc, nil } // Next try to get a reflection-free encoder. From fdc76b5e3e406dce1dd6c9b6ac671590823e3ebf Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 11 Apr 2025 12:06:11 -0600 Subject: [PATCH 13/13] GODRIVER-3455 Touch ups --- bson/default_value_encoders.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 8319298a69..881acfe73d 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -179,11 +179,7 @@ func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error func objectIDEncodeValueRF(_ EncodeContext, vw ValueWriter, val any) error { objID, ok := val.(ObjectID) if !ok { - return ValueEncoderError{ - Name: "ObjectIDEncodeValue", - Types: []reflect.Type{tOID}, - Received: reflect.ValueOf(val), - } + return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: reflect.ValueOf(val)} } return vw.WriteObjectID(objID)