Skip to content

Commit

Permalink
extract Container as a interface
Browse files Browse the repository at this point in the history
  • Loading branch information
mylxsw committed Dec 8, 2019
1 parent 425f175 commit c8ee6e3
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 52 deletions.
92 changes: 46 additions & 46 deletions container.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type Entity struct {
index int // the index in the container

prototype bool
c *Container
c *containerImpl
}

// Value instance value if not initialized
Expand Down Expand Up @@ -78,24 +78,24 @@ func (e *Entity) createValue(provider func() []*Entity) (interface{}, error) {
return returnValues[0].Interface(), nil
}

// Container is a dependency injection container
type Container struct {
// containerImpl is a dependency injection container
type containerImpl struct {
lock sync.RWMutex

objects map[interface{}]*Entity
objectSlices []*Entity

parent *Container
parent Container
}

// New create a new container
func New() *Container {
cc := &Container{
func New() Container {
cc := &containerImpl{
objects: make(map[interface{}]*Entity),
objectSlices: make([]*Entity, 0),
}

cc.MustSingleton(func() *Container {
cc.MustSingleton(func() Container {
return cc
})

Expand All @@ -107,13 +107,13 @@ func New() *Container {
}

// NewWithContext create a new container with context support
func NewWithContext(ctx context.Context) *Container {
cc := &Container{
func NewWithContext(ctx context.Context) Container {
cc := &containerImpl{
objects: make(map[interface{}]*Entity),
objectSlices: make([]*Entity, 0),
}

cc.MustSingleton(func() *Container {
cc.MustSingleton(func() Container {
return cc
})

Expand All @@ -126,78 +126,78 @@ func NewWithContext(ctx context.Context) *Container {

// Extend create a new container and it's parent is supplied container
// If can not found a binding from current container, it will search from parents
func Extend(c *Container) *Container {
cc := &Container{
func Extend(c Container) Container {
cc := &containerImpl{
objects: make(map[interface{}]*Entity),
objectSlices: make([]*Entity, 0),
parent: c,
}

cc.MustSingleton(func() *Container {
cc.MustSingleton(func() Container {
return cc
})

return cc
}

// ExtendFrom extend from a parent Container
func (c *Container) ExtendFrom(parent *Container) {
// ExtendFrom extend from a parent containerImpl
func (c *containerImpl) ExtendFrom(parent Container) {
c.parent = parent
}

// Must if err is not nil, panic it
func (c *Container) Must(err error) {
func (c *containerImpl) Must(err error) {
if err != nil {
panic(err)
}
}

// Prototype bind a prototype
// initialize func(...) (value, error)
func (c *Container) Prototype(initialize interface{}) error {
func (c *containerImpl) Prototype(initialize interface{}) error {
return c.Bind(initialize, true)
}

// MustPrototype bind a prototype, if failed then panic
func (c *Container) MustPrototype(initialize interface{}) {
func (c *containerImpl) MustPrototype(initialize interface{}) {
c.Must(c.Prototype(initialize))
}

// PrototypeWithKey bind a prototype with key
// initialize func(...) (value, error)
func (c *Container) PrototypeWithKey(key interface{}, initialize interface{}) error {
func (c *containerImpl) PrototypeWithKey(key interface{}, initialize interface{}) error {
return c.BindWithKey(key, initialize, true)
}

// MustPrototypeWithKey bind a prototype with key, it failed, then panic
func (c *Container) MustPrototypeWithKey(key interface{}, initialize interface{}) {
func (c *containerImpl) MustPrototypeWithKey(key interface{}, initialize interface{}) {
c.Must(c.PrototypeWithKey(key, initialize))
}

// Singleton bind a singleton
// initialize func(...) (value, error)
func (c *Container) Singleton(initialize interface{}) error {
func (c *containerImpl) Singleton(initialize interface{}) error {
return c.Bind(initialize, false)
}

// MustSingleton bind a singleton, if bind failed, then panic
func (c *Container) MustSingleton(initialize interface{}) {
func (c *containerImpl) MustSingleton(initialize interface{}) {
c.Must(c.Singleton(initialize))
}

// SingletonWithKey bind a singleton with key
// initialize func(...) (value, error)
func (c *Container) SingletonWithKey(key interface{}, initialize interface{}) error {
func (c *containerImpl) SingletonWithKey(key interface{}, initialize interface{}) error {
return c.BindWithKey(key, initialize, false)
}

// MustSingletonWithKey bind a singleton with key, if failed, then panic
func (c *Container) MustSingletonWithKey(key interface{}, initialize interface{}) {
func (c *containerImpl) MustSingletonWithKey(key interface{}, initialize interface{}) {
c.Must(c.SingletonWithKey(key, initialize))
}

// BindValue bind a value to container
func (c *Container) BindValue(key interface{}, value interface{}) error {
func (c *containerImpl) BindValue(key interface{}, value interface{}) error {
if value == nil {
return ErrInvalidArgs("value is nil")
}
Expand Down Expand Up @@ -226,12 +226,12 @@ func (c *Container) BindValue(key interface{}, value interface{}) error {
}

// MustBindValue bind a value to container, if failed, panic it
func (c *Container) MustBindValue(key interface{}, value interface{}) {
func (c *containerImpl) MustBindValue(key interface{}, value interface{}) {
c.Must(c.BindValue(key, value))
}

// ServiceProvider create a provider from initializes
func (c *Container) ServiceProvider(initializes ...interface{}) (func() []*Entity, error) {
func (c *containerImpl) ServiceProvider(initializes ...interface{}) (func() []*Entity, error) {
entities := make([]*Entity, len(initializes))
for i, init := range initializes {
entity, err := c.NewEntity(init, false)
Expand All @@ -248,7 +248,7 @@ func (c *Container) ServiceProvider(initializes ...interface{}) (func() []*Entit
}

// NewEntity create a new entity
func (c *Container) NewEntity(initialize interface{}, prototype bool) (*Entity, error) {
func (c *containerImpl) NewEntity(initialize interface{}, prototype bool) (*Entity, error) {
if !reflect.ValueOf(initialize).IsValid() {
return nil, ErrInvalidArgs("initialize is nil")
}
Expand All @@ -262,7 +262,7 @@ func (c *Container) NewEntity(initialize interface{}, prototype bool) (*Entity,
return c.newEntity(typ, typ, initialize, prototype), nil
}

func (c *Container) newEntity(key interface{}, typ reflect.Type, initialize interface{}, prototype bool) *Entity {
func (c *containerImpl) newEntity(key interface{}, typ reflect.Type, initialize interface{}, prototype bool) *Entity {
entity := Entity{
initializeFunc: initialize,
key: key,
Expand All @@ -277,7 +277,7 @@ func (c *Container) newEntity(key interface{}, typ reflect.Type, initialize inte

// Bind bind a initialize for object
// initialize func(...) (value, error)
func (c *Container) Bind(initialize interface{}, prototype bool) error {
func (c *containerImpl) Bind(initialize interface{}, prototype bool) error {
if !reflect.ValueOf(initialize).IsValid() {
return ErrInvalidArgs("initialize is nil")
}
Expand All @@ -292,13 +292,13 @@ func (c *Container) Bind(initialize interface{}, prototype bool) error {
}

// MustBind bind a initialize, if failed then panic
func (c *Container) MustBind(initialize interface{}, prototype bool) {
func (c *containerImpl) MustBind(initialize interface{}, prototype bool) {
c.Must(c.Bind(initialize, prototype))
}

// BindWithKey bind a initialize for object with a key
// initialize func(...) (value, error)
func (c *Container) BindWithKey(key interface{}, initialize interface{}, prototype bool) error {
func (c *containerImpl) BindWithKey(key interface{}, initialize interface{}, prototype bool) error {
if !reflect.ValueOf(initialize).IsValid() {
return ErrInvalidArgs("initialize is nil")
}
Expand All @@ -312,25 +312,25 @@ func (c *Container) BindWithKey(key interface{}, initialize interface{}, prototy
}

// MustBindWithKey bind a initialize for object with a key, if failed then panic
func (c *Container) MustBindWithKey(key interface{}, initialize interface{}, prototype bool) {
func (c *containerImpl) MustBindWithKey(key interface{}, initialize interface{}, prototype bool) {
c.Must(c.BindWithKey(key, initialize, prototype))
}

// Resolve inject args for func by callback
// callback func(...)
func (c *Container) Resolve(callback interface{}) error {
func (c *containerImpl) Resolve(callback interface{}) error {
_, err := c.Call(callback)
return err
}

// MustResolve inject args for func by callback
func (c *Container) MustResolve(callback interface{}) {
func (c *containerImpl) MustResolve(callback interface{}) {
c.Must(c.Resolve(callback))
}

// ResolveWithError inject args for func by callback
// callback func(...) error
func (c *Container) ResolveWithError(callback interface{}) error {
func (c *containerImpl) ResolveWithError(callback interface{}) error {
results, err := c.Call(callback)
if err != nil {
return newResolveError(err)
Expand All @@ -346,7 +346,7 @@ func (c *Container) ResolveWithError(callback interface{}) error {
}

// CallWithProvider execute the callback with extra service provider
func (c *Container) CallWithProvider(callback interface{}, provider func() []*Entity) ([]interface{}, error) {
func (c *containerImpl) CallWithProvider(callback interface{}, provider func() []*Entity) ([]interface{}, error) {
callbackValue := reflect.ValueOf(callback)
if !callbackValue.IsValid() {
return nil, ErrInvalidArgs("callback is nil")
Expand All @@ -367,16 +367,16 @@ func (c *Container) CallWithProvider(callback interface{}, provider func() []*En
}

// Call call a callback function and return it's results
func (c *Container) Call(callback interface{}) ([]interface{}, error) {
func (c *containerImpl) Call(callback interface{}) ([]interface{}, error) {
return c.CallWithProvider(callback, nil)
}

// Get get instance by key from container
func (c *Container) Get(key interface{}) (interface{}, error) {
func (c *containerImpl) Get(key interface{}) (interface{}, error) {
return c.get(key, nil)
}

func (c *Container) get(key interface{}, provider func() []*Entity) (interface{}, error) {
func (c *containerImpl) get(key interface{}, provider func() []*Entity) (interface{}, error) {
keyReflectType, ok := key.(reflect.Type)
if !ok {
keyReflectType = reflect.TypeOf(key)
Expand Down Expand Up @@ -410,14 +410,14 @@ func (c *Container) get(key interface{}, provider func() []*Entity) (interface{}
}

if c.parent != nil {
return c.parent.get(key, nil)
return c.parent.Get(key)
}

return nil, ErrObjectNotFound(fmt.Sprintf("key=%s", key))
}

// MustGet get instance by key from container
func (c *Container) MustGet(key interface{}) interface{} {
func (c *containerImpl) MustGet(key interface{}) interface{} {
res, err := c.Get(key)
if err != nil {
panic(err)
Expand All @@ -426,7 +426,7 @@ func (c *Container) MustGet(key interface{}) interface{} {
return res
}

func (c *Container) bindWith(key interface{}, typ reflect.Type, initialize interface{}, prototype bool) error {
func (c *containerImpl) bindWith(key interface{}, typ reflect.Type, initialize interface{}, prototype bool) error {
c.lock.Lock()
defer c.lock.Unlock()

Expand All @@ -443,7 +443,7 @@ func (c *Container) bindWith(key interface{}, typ reflect.Type, initialize inter
return nil
}

func (c *Container) funcArgs(t reflect.Type, provider func() []*Entity) ([]reflect.Value, error) {
func (c *containerImpl) funcArgs(t reflect.Type, provider func() []*Entity) ([]reflect.Value, error) {
argsSize := t.NumIn()
argValues := make([]reflect.Value, argsSize)
for i := 0; i < argsSize; i++ {
Expand All @@ -459,7 +459,7 @@ func (c *Container) funcArgs(t reflect.Type, provider func() []*Entity) ([]refle
return argValues, nil
}

func (c *Container) instanceOfType(t reflect.Type, provider func() []*Entity) (reflect.Value, error) {
func (c *containerImpl) instanceOfType(t reflect.Type, provider func() []*Entity) (reflect.Value, error) {
arg, err := c.get(t, provider)
if err != nil {
return reflect.Value{}, ErrArgNotInstanced(err.Error())
Expand All @@ -469,7 +469,7 @@ func (c *Container) instanceOfType(t reflect.Type, provider func() []*Entity) (r
}

// Keys return all keys
func (c *Container) Keys() []interface{} {
func (c *containerImpl) Keys() []interface{} {
c.lock.RLock()
defer c.lock.RUnlock()

Expand Down
10 changes: 5 additions & 5 deletions container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestPrototype(t *testing.T) {
c := container.New()

c.MustBindValue("conn_str", "root:root@/my_db?charset=utf8")
c.MustSingleton(func(c *container.Container) (*UserRepo, error) {
c.MustSingleton(func(c container.Container) (*UserRepo, error) {
connStr, err := c.Get("conn_str")
if err != nil {
return nil, err
Expand Down Expand Up @@ -88,7 +88,7 @@ func TestPrototype(t *testing.T) {
}

{
c.MustResolve(func(cc *container.Container) {
c.MustResolve(func(cc container.Container) {
userService, err := c.Get((*UserService)(nil))
if err != nil {
t.Error(err)
Expand All @@ -105,7 +105,7 @@ func TestPrototype(t *testing.T) {
func TestInterfaceInjection(t *testing.T) {
c := container.New()
c.MustBindValue("conn_str", "root:root@/my_db?charset=utf8")
c.MustSingleton(func(c *container.Container) (*UserRepo, error) {
c.MustSingleton(func(c container.Container) (*UserRepo, error) {
connStr, err := c.Get("conn_str")
if err != nil {
return nil, err
Expand Down Expand Up @@ -167,7 +167,7 @@ type TestObject struct {
func TestWithProvider(t *testing.T) {
c := container.New()
c.MustBindValue("conn_str", "root:root@/my_db?charset=utf8")
c.MustSingleton(func(c *container.Container) (*UserRepo, error) {
c.MustSingleton(func(c container.Container) (*UserRepo, error) {
connStr, err := c.Get("conn_str")
if err != nil {
return nil, err
Expand Down Expand Up @@ -202,7 +202,7 @@ func TestWithProvider(t *testing.T) {
func TestExtend(t *testing.T) {
c := container.New()
c.MustBindValue("conn_str", "root:root@/my_db?charset=utf8")
c.MustSingleton(func(c *container.Container) (*UserRepo, error) {
c.MustSingleton(func(c container.Container) (*UserRepo, error) {
connStr, err := c.Get("conn_str")
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit c8ee6e3

Please sign in to comment.