Skip to content

Commit

Permalink
map/multi binding: properly handle scopes
Browse files Browse the repository at this point in the history
  • Loading branch information
bastianccm committed Feb 4, 2020
1 parent 975c0e7 commit de4bf3b
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 36 deletions.
65 changes: 29 additions & 36 deletions dingo.go
Expand Up @@ -214,13 +214,15 @@ func (injector *Injector) getInstance(typ interface{}, annotatedWith string, cir
}
}

return injector.resolveType(oftype, annotatedWith, false, circularTrace)
return injector.getInstanceOfTypeWithAnnotation(oftype, annotatedWith, nil, false, circularTrace)
}

func (injector *Injector) findBinding(t reflect.Type, annotation string) *Binding {
func (injector *Injector) findBindingForAnnotatedType(t reflect.Type, annotation string) *Binding {
if len(injector.bindings[t]) > 0 {
if binding := injector.lookupBinding(t, annotation); binding != nil {
return binding
for _, binding := range injector.bindings[t] {
if binding.annotatedWith == annotation {
return binding
}
}
}

Expand All @@ -231,27 +233,29 @@ func (injector *Injector) findBinding(t reflect.Type, annotation string) *Bindin

// ask parent
if injector.parent != nil {
return injector.parent.findBinding(t, annotation)
return injector.parent.findBindingForAnnotatedType(t, annotation)
}

return nil
}

// resolveType resolves a requested type, with annotation
func (injector *Injector) resolveType(t reflect.Type, annotation string, optional bool, circularTrace []circularTraceEntry) (reflect.Value, error) {
// getInstanceOfTypeWithAnnotation resolves a requested type, with annotation
func (injector *Injector) getInstanceOfTypeWithAnnotation(t reflect.Type, annotation string, binding *Binding, optional bool, circularTrace []circularTraceEntry) (reflect.Value, error) {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}

var final reflect.Value
var err error

if binding := injector.findBinding(t, annotation); binding != nil {
if typeBinding := injector.findBindingForAnnotatedType(t, annotation); typeBinding != nil {
binding = typeBinding
}
if binding != nil {
if binding.scope != nil {
if scope, ok := injector.scopes[reflect.TypeOf(binding.scope)]; ok {
//final = scope.ResolveType(t, annotation, injector.internalResolveType)
if final, err = scope.ResolveType(t, annotation, func(t reflect.Type, annotation string, optional bool) (reflect.Value, error) {
return injector.internalResolveType(t, annotation, optional, circularTrace)
return injector.createInstanceOfAnnotatedType(t, annotation, optional, circularTrace)
}); err != nil {
return reflect.Value{}, err
}
Expand All @@ -265,7 +269,7 @@ func (injector *Injector) resolveType(t reflect.Type, annotation string, optiona
}

if !final.IsValid() {
if final, err = injector.internalResolveType(t, annotation, optional, circularTrace); err != nil {
if final, err = injector.createInstanceOfAnnotatedType(t, annotation, optional, circularTrace); err != nil {
return reflect.Value{}, err
}
}
Expand All @@ -274,22 +278,22 @@ func (injector *Injector) resolveType(t reflect.Type, annotation string, optiona
return reflect.Value{}, fmtErrorf("can not resolve %q", t.String())
}

final = injector.intercept(final, t)

return final, nil
return injector.intercept(final, t)
}

func (injector *Injector) intercept(final reflect.Value, t reflect.Type) reflect.Value {
func (injector *Injector) intercept(final reflect.Value, t reflect.Type) (reflect.Value, error) {
for _, interceptor := range injector.interceptor[t] {
of := final
final = reflect.New(interceptor)
injector.requestInjection(final.Interface(), traceCircular)
if err := injector.requestInjection(final.Interface(), traceCircular); err != nil {
return reflect.Value{}, err
}
final.Elem().Field(0).Set(of)
}
if injector.parent != nil {
return injector.parent.intercept(final, t)
}
return final
return final, nil
}

type errUnbound struct {
Expand All @@ -314,15 +318,15 @@ func (injector *Injector) resolveBinding(binding *Binding, t reflect.Type, optio
if binding.to == t {
return reflect.Value{}, fmtErrorf("circular from %q to %q (annotated with: %q)", t, binding.to, binding.annotatedWith)
}
return injector.resolveType(binding.to, "", optional, circularTrace)
return injector.getInstanceOfTypeWithAnnotation(binding.to, "", binding, optional, circularTrace)
}

return reflect.Value{}, errUnbound{binding: binding, typ: t}
}

// internalResolveType resolves a type request with the current injector
func (injector *Injector) internalResolveType(t reflect.Type, annotation string, optional bool, circularTrace []circularTraceEntry) (reflect.Value, error) {
if binding := injector.findBinding(t, annotation); binding != nil {
// createInstanceOfAnnotatedType resolves a type request with the current injector
func (injector *Injector) createInstanceOfAnnotatedType(t reflect.Type, annotation string, optional bool, circularTrace []circularTraceEntry) (reflect.Value, error) {
if binding := injector.findBindingForAnnotatedType(t, annotation); binding != nil {
r, err := injector.resolveBinding(binding, t, optional, circularTrace)
// todo: go 1.13/1.14: if err == nil || !errors.As(err, new(errUnbound)) {
if err == nil {
Expand All @@ -333,7 +337,7 @@ func (injector *Injector) internalResolveType(t reflect.Type, annotation string,

// todo: proper testcases
if annotation != "" {
return injector.resolveType(binding.typeof, "", false, circularTrace)
return injector.getInstanceOfTypeWithAnnotation(binding.typeof, "", binding, false, circularTrace)
}
}

Expand Down Expand Up @@ -420,12 +424,12 @@ func (injector *Injector) createProvider(t reflect.Type, annotation string, opti

// multibindings
if res.Elem().Kind() == reflect.Slice {
return ret(injector.internalResolveType(t.Out(0), annotation, optional, circularTrace))
return ret(injector.createInstanceOfAnnotatedType(t.Out(0), annotation, optional, circularTrace))
}

// mapbindings
if res.Elem().Kind() == reflect.Map && res.Elem().Type().Key().Kind() == reflect.String {
return ret(injector.internalResolveType(t.Out(0), annotation, optional, circularTrace))
return ret(injector.createInstanceOfAnnotatedType(t.Out(0), annotation, optional, circularTrace))
}

r := ret(injector.getInstance(t.Out(0), annotation, circularTrace))
Expand Down Expand Up @@ -575,17 +579,6 @@ func (injector *Injector) resolveMapbinding(t reflect.Type, annotation string, o
return reflect.MakeMap(t), nil
}

// lookupBinding search a binding with the corresponding annotation
func (injector *Injector) lookupBinding(t reflect.Type, annotation string) *Binding {
for _, binding := range injector.bindings[t] {
if binding.annotatedWith == annotation {
return binding
}
}

return nil
}

// BindMulti binds multiple concrete types to the same abstract type / interface
func (injector *Injector) BindMulti(what interface{}) *Binding {
bindtype := reflect.TypeOf(what)
Expand Down Expand Up @@ -738,7 +731,7 @@ func (injector *Injector) requestInjection(object interface{}, circularTrace []c
}
tag = strings.Split(tag, ",")[0]

instance, err := injector.resolveType(field.Type(), tag, optional, circularTrace)
instance, err := injector.getInstanceOfTypeWithAnnotation(field.Type(), tag, nil, optional, circularTrace)
if err != nil {
return wrapErr(err)
}
Expand Down
38 changes: 38 additions & 0 deletions multi_dingo_test.go
Expand Up @@ -232,3 +232,41 @@ func TestMapBindingProvider(t *testing.T) {
assert.Equal(t, "testkey2 instance", testmap["testkey2"]())
assert.Equal(t, "testkey3 instance", testmap["testkey3"]())
}

func TestMapBindingSingleton(t *testing.T) {
injector, err := NewInjector()
assert.NoError(t, err)

injector.BindMap(new(mapBindInterface), "a").To("a").In(Singleton)
injector.BindMap(new(mapBindInterface), "b").To("b")

i, err := injector.GetInstance(new(mapBindTest1))
assert.NoError(t, err)

first := i.(*mapBindTest1).Mbp()["a"]
second := i.(*mapBindTest1).Mbp()["a"]

assert.True(t, first == second)

first = i.(*mapBindTest1).Mbp()["b"]
second = i.(*mapBindTest1).Mbp()["b"]

assert.False(t, first == second)
}

func TestMultiBindingSingleton(t *testing.T) {
injector, err := NewInjector()
assert.NoError(t, err)

injector.BindMulti(new(mapBindInterface)).To("a").In(Singleton)

i, err := injector.GetInstance(new(multiBindTest))
assert.NoError(t, err)
first := i.(*multiBindTest).Mb[0]

i, err = injector.GetInstance(new(multiBindTest))
assert.NoError(t, err)
second := i.(*multiBindTest).Mb[0]

assert.Same(t, first, second)
}
39 changes: 39 additions & 0 deletions scope_test.go
Expand Up @@ -109,3 +109,42 @@ func TestScopeWithSubDependencies(t *testing.T) {
})
}
}

type inheritedScopeIface interface{}
type inheritedScopeStruct struct{}
type inheritedScopeInjected struct {
i inheritedScopeIface
s *inheritedScopeStruct
}

func (s *inheritedScopeInjected) Inject(ss *inheritedScopeStruct, si inheritedScopeIface) {
s.s = ss
s.i = si
}

func TestInheritedScope(t *testing.T) {
injector, err := NewInjector()
assert.NoError(t, err)

injector.Bind(new(inheritedScopeStruct)).In(ChildSingleton)
injector.Bind(new(inheritedScopeIface)).To(new(inheritedScopeStruct))

injector, err = injector.Child()
assert.NoError(t, err)

i, err := injector.GetInstance(new(inheritedScopeInjected))
assert.NoError(t, err)
firstS := i.(*inheritedScopeInjected)
i, err = injector.GetInstance(new(inheritedScopeInjected))
assert.NoError(t, err)
secondS := i.(*inheritedScopeInjected)
assert.Same(t, firstS.s, secondS.s)

i, err = injector.GetInstance(new(inheritedScopeInjected))
assert.NoError(t, err)
firstI := i.(*inheritedScopeInjected)
i, err = injector.GetInstance(new(inheritedScopeInjected))
assert.NoError(t, err)
secondI := i.(*inheritedScopeInjected)
assert.Same(t, firstI.i, secondI.i)
}

0 comments on commit de4bf3b

Please sign in to comment.