Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
mylxsw committed May 24, 2020
1 parent 0a55a01 commit 9377311
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 8 deletions.
40 changes: 32 additions & 8 deletions container.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,24 +366,21 @@ func (c *containerImpl) Get(key interface{}) (interface{}, error) {
}

func (c *containerImpl) get(key interface{}, provider func() []*Entity) (interface{}, error) {
keyReflectType, ok := key.(reflect.Type)
if !ok {
keyReflectType = reflect.TypeOf(key)
}
lookupKey := c.buildKeyLookupFunc(key)

c.lock.RLock()
defer c.lock.RUnlock()

if provider != nil {
for _, obj := range provider() {
if obj.key == key || obj.key == keyReflectType {
if lookupKey(obj.key) {
return obj.Value(provider)
}
}
}

for _, obj := range c.objectSlices {
if obj.key == key || obj.key == keyReflectType {
if lookupKey(obj.key) {
return obj.Value(provider)
}
}
Expand All @@ -392,7 +389,34 @@ func (c *containerImpl) get(key interface{}, provider func() []*Entity) (interfa
return c.parent.Get(key)
}

return nil, buildObjectNotFoundError(fmt.Sprintf("key=%s not found", key))
return nil, buildObjectNotFoundError(fmt.Sprintf("key=%#v not found", key))
}

// buildKeyLookupFunc 构建用于查询 key 是否存在的函数
// key 匹配规则为
// 1. matchKey == lookupKey ,则匹配
// 2. matchKey == type(lookupKey) ,则匹配
// 3. 如果 lookupKey 是指向接口的指针,则解析成接口本身,与 matchKey 比较,相等则匹配
func (c *containerImpl) buildKeyLookupFunc(lookupKey interface{}) func(matchKey interface{}) bool {
keyReflectType, ok := lookupKey.(reflect.Type)
if !ok {
keyReflectType = reflect.TypeOf(lookupKey)
}

keyLookupMap := make(map[interface{}]bool)
keyLookupMap[lookupKey] = true
keyLookupMap[keyReflectType] = true

if keyReflectType.Kind() == reflect.Ptr {
typeUnderPointer := keyReflectType.Elem()
if typeUnderPointer.Kind() == reflect.Interface {
keyLookupMap[typeUnderPointer] = true
}
}
return func(key interface{}) bool {
_, ok := keyLookupMap[key]
return ok
}
}

// MustGet get instance by key from container
Expand Down Expand Up @@ -450,7 +474,7 @@ func (c *containerImpl) CanOverride(key interface{}) (bool, error) {

obj, ok := c.objects[key]
if !ok {
return true, buildObjectNotFoundError(fmt.Sprintf("key=%v not found", key))
return true, buildObjectNotFoundError(fmt.Sprintf("key=%#v not found", key))
}

return obj.override, nil
Expand Down
14 changes: 14 additions & 0 deletions container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ func TestPrototype(t *testing.T) {
t.Errorf("test failed: %s", err)
return
}
{
userService, err := c.Get(new(UserService))
if err != nil {
t.Error(err)
return
}

fmt.Println(userService.(*UserService).GetUser())
}
// reflect.TypeOf((*UserService)(nil))
{
userService, err := c.Get(reflect.TypeOf((*UserService)(nil)))
Expand Down Expand Up @@ -298,6 +307,11 @@ func TestContainerImpl_Override(t *testing.T) {
return demo2{}
})

res := c.MustGet(new(InterfaceDemo))
if "demo2" != res.(InterfaceDemo).String() {
t.Error("test failed")
}

c.MustResolve(func(demo InterfaceDemo) {
if "demo2" != demo.String() {
t.Error("test failed")
Expand Down

0 comments on commit 9377311

Please sign in to comment.