Skip to content

Commit

Permalink
GODRIVER-2887 Remove use of reflect.Value.MethodByName in bson (#1308)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlievieth authored and qingyang-hu committed Aug 1, 2023
1 parent 9318bc2 commit 8857a04
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 49 deletions.
20 changes: 10 additions & 10 deletions bson/bsoncodec/default_value_decoders.go
Expand Up @@ -1540,12 +1540,12 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr
return err
}

fn := val.Convert(tValueUnmarshaler).MethodByName("UnmarshalBSONValue")
errVal := fn.Call([]reflect.Value{reflect.ValueOf(t), reflect.ValueOf(src)})[0]
if !errVal.IsNil() {
return errVal.Interface().(error)
m, ok := val.Interface().(ValueUnmarshaler)
if !ok {
// NB: this error should be unreachable due to the above checks
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
}
return nil
return m.UnmarshalBSONValue(t, src)
}

// UnmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations.
Expand Down Expand Up @@ -1588,12 +1588,12 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr bsonr
val = val.Addr() // If the type doesn't implement the interface, a pointer to it must.
}

fn := val.Convert(tUnmarshaler).MethodByName("UnmarshalBSON")
errVal := fn.Call([]reflect.Value{reflect.ValueOf(src)})[0]
if !errVal.IsNil() {
return errVal.Interface().(error)
m, ok := val.Interface().(Unmarshaler)
if !ok {
// NB: this error should be unreachable due to the above checks
return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val}
}
return nil
return m.UnmarshalBSON(src)
}

// EmptyInterfaceDecodeValue is the ValueDecoderFunc for interface{}.
Expand Down
11 changes: 10 additions & 1 deletion bson/bsoncodec/default_value_decoders_test.go
Expand Up @@ -1530,13 +1530,22 @@ func TestDefaultValueDecoders(t *testing.T) {
errors.New("copy error"),
},
{
"Unmarshaler",
// Only the pointer form of testUnmarshaler implements Unmarshaler
"value does not implement Unmarshaler",
testUnmarshaler{Val: bsoncore.AppendDouble(nil, 3.14159)},
nil,
&bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)},
bsonrwtest.ReadDouble,
nil,
},
{
"Unmarshaler",
&testUnmarshaler{Val: bsoncore.AppendDouble(nil, 3.14159)},
nil,
&bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)},
bsonrwtest.ReadDouble,
nil,
},
},
},
{
Expand Down
56 changes: 34 additions & 22 deletions bson/bsoncodec/default_value_encoders.go
Expand Up @@ -564,12 +564,14 @@ func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw bs
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
}

fn := val.Convert(tValueMarshaler).MethodByName("MarshalBSONValue")
returns := fn.Call(nil)
if !returns[2].IsNil() {
return returns[2].Interface().(error)
m, ok := val.Interface().(ValueMarshaler)
if !ok {
return vw.WriteNull()
}
t, data, err := m.MarshalBSONValue()
if err != nil {
return err
}
t, data := returns[0].Interface().(bsontype.Type), returns[1].Interface().([]byte)
return bsonrw.Copier{}.CopyValueFromBytes(vw, t, data)
}

Expand All @@ -593,12 +595,14 @@ func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw bsonrw.
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
}

fn := val.Convert(tMarshaler).MethodByName("MarshalBSON")
returns := fn.Call(nil)
if !returns[1].IsNil() {
return returns[1].Interface().(error)
m, ok := val.Interface().(Marshaler)
if !ok {
return vw.WriteNull()
}
data, err := m.MarshalBSON()
if err != nil {
return err
}
data := returns[0].Interface().([]byte)
return bsonrw.Copier{}.CopyValueFromBytes(vw, bsontype.EmbeddedDocument, data)
}

Expand All @@ -622,23 +626,31 @@ func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.Val
return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val}
}

fn := val.Convert(tProxy).MethodByName("ProxyBSON")
returns := fn.Call(nil)
if !returns[1].IsNil() {
return returns[1].Interface().(error)
m, ok := val.Interface().(Proxy)
if !ok {
return vw.WriteNull()
}
v, err := m.ProxyBSON()
if err != nil {
return err
}
if v == nil {
encoder, err := ec.LookupEncoder(nil)
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, reflect.ValueOf(nil))
}
data := returns[0]
var encoder ValueEncoder
var err error
if data.Elem().IsValid() {
encoder, err = ec.LookupEncoder(data.Elem().Type())
} else {
encoder, err = ec.LookupEncoder(nil)
vv := reflect.ValueOf(v)
switch vv.Kind() {
case reflect.Ptr, reflect.Interface:
vv = vv.Elem()
}
encoder, err := ec.LookupEncoder(vv.Type())
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, data.Elem())
return encoder.EncodeValue(ec, vw, vv)
}

// JavaScriptEncodeValue is the ValueEncoderFunc for the primitive.JavaScript type.
Expand Down
38 changes: 22 additions & 16 deletions bson/mgocompat/setter_getter.go
Expand Up @@ -7,6 +7,7 @@
package mgocompat

import (
"errors"
"reflect"

"go.mongodb.org/mongo-driver/bson"
Expand Down Expand Up @@ -73,16 +74,15 @@ func SetterDecodeValue(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val ref
return err
}

fn := val.Convert(tSetter).MethodByName("SetBSON")

errVal := fn.Call([]reflect.Value{reflect.ValueOf(bson.RawValue{Type: t, Value: src})})[0]
if !errVal.IsNil() {
err = errVal.Interface().(error)
if err == ErrSetZero {
val.Set(reflect.Zero(val.Type()))
return nil
m, ok := val.Interface().(Setter)
if !ok {
return bsoncodec.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
}
if err := m.SetBSON(bson.RawValue{Type: t, Value: src}); err != nil {
if !errors.Is(err, ErrSetZero) {
return err
}
return err
val.Set(reflect.Zero(val.Type()))
}
return nil
}
Expand All @@ -104,17 +104,23 @@ func GetterEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val re
return bsoncodec.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val}
}

fn := val.Convert(tGetter).MethodByName("GetBSON")
returns := fn.Call(nil)
if !returns[1].IsNil() {
return returns[1].Interface().(error)
m, ok := val.Interface().(Getter)
if !ok {
return vw.WriteNull()
}
x, err := m.GetBSON()
if err != nil {
return err
}
if x == nil {
return vw.WriteNull()
}
intermediate := returns[0]
encoder, err := ec.Registry.LookupEncoder(intermediate.Type())
vv := reflect.ValueOf(x)
encoder, err := ec.Registry.LookupEncoder(vv.Type())
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, intermediate)
return encoder.EncodeValue(ec, vw, vv)
}

// isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type
Expand Down

0 comments on commit 8857a04

Please sign in to comment.