Skip to content

Commit

Permalink
Implement AddParent() method
Browse files Browse the repository at this point in the history
  • Loading branch information
chirino committed Jun 14, 2021
1 parent 02504b0 commit df8c55d
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 1 deletion.
6 changes: 6 additions & 0 deletions container.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ func (c *Container) Provide(constructor Constructor, options ...ProvideOption) e
return nil
}

// AddParent adds a parent container. Types are resolved from the container,
// it's parents, and ancestors. An error is a cycle is detected in ancestry tree.
func (c *Container) AddParent(parent *Container) error {
return c.schema.addParent(parent.schema)
}

// ProvideValue provides value as is.
func (c *Container) ProvideValue(value Value, options ...ProvideOption) error {
if err := c.provideValue(value, options...); err != nil {
Expand Down
44 changes: 44 additions & 0 deletions container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1306,3 +1306,47 @@ func TestContainer_Cleanup(t *testing.T) {
require.Equal(t, []string{"server", "mux"}, cleanupCalls)
})
}

func TestParentContainer(t *testing.T) {
t.Run("provide in parent and resolve in child", func(t *testing.T) {
parent, err := di.New()
require.NoError(t, err)
require.NotNil(t, parent)
mux := &http.ServeMux{}
err = parent.ProvideValue(mux, di.As(new(http.Handler)))
require.NoError(t, err)
child, err := di.New()
require.NoError(t, err)
require.NotNil(t, child)
err = child.Provide(func(handler http.Handler) *http.Server {
return &http.Server{
Handler: handler,
}
})
require.NoError(t, err)
err = child.AddParent(parent)
require.NoError(t, err)
var server *http.Server
err = child.Resolve(&server)
require.NoError(t, err)
require.Equal(t, fmt.Sprintf("%p", mux), fmt.Sprintf("%p", server.Handler))
})

t.Run("chain errors", func(t *testing.T) {
parent, err := di.New()
require.NoError(t, err)
require.NotNil(t, parent)
err = parent.AddParent(parent)
require.Contains(t, err.Error(), "self cycle detected")
child, err := di.New()
require.NoError(t, err)
require.NotNil(t, child)
err = child.AddParent(parent)
require.NoError(t, err)
err = parent.AddParent(child)
require.Contains(t, err.Error(), "cycle detected")
err = child.AddParent(parent)
require.Contains(t, err.Error(), "parent already chained")
})

}
19 changes: 19 additions & 0 deletions docs/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,22 @@ if err != nil {
container.Cleanup() // file was closed
```

### Container Chains
You can chain containers together so that values can be resolved from a parent container. The child

```go
parent, err := container.New(
di.Provide(NewServer),
di.Provide(NewServeMux),
)

child, err := container.New()

err = child.AddParent(parent)
if err != nil {
// handle error
}

var server *http.Server
err := child.Resolve(&server)
```
42 changes: 41 additions & 1 deletion schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type schema interface {

// schema is a dependency injection schema.
type defaultSchema struct {
parents []*defaultSchema
nodes map[reflect.Type][]*node
cleanups []func()
}
Expand Down Expand Up @@ -43,7 +44,7 @@ func (s *defaultSchema) register(n *node) {

// find finds provideFunc by its reflect.Type and Tags.
func (s *defaultSchema) find(t reflect.Type, tags Tags) (*node, error) {
nodes, ok := s.nodes[t]
nodes, ok := s.list(t)
// type found
if ok {
matched := matchTags(nodes, tags)
Expand Down Expand Up @@ -94,3 +95,42 @@ func (s *defaultSchema) group(t reflect.Type, tags Tags) (*node, error) {
}
return node, nil
}

// list lists all the nodes of its reflect.Type
func (s *defaultSchema) list(t reflect.Type) ([]*node, bool) {
nodes, ok := s.nodes[t]
for _, parent := range s.parents {
if n, o := parent.list(t); o {
nodes = append(nodes, n...)
ok = true
}
}
return nodes, ok
}

// isAncestor returns true if a
func (s *defaultSchema) isAncestor(a *defaultSchema) bool {
for _, parent := range s.parents {
if parent == a {
return true
}
if parent.isAncestor(a) {
return true
}
}
return false
}

func (s *defaultSchema) addParent(parent *defaultSchema) error {
if parent == s {
return fmt.Errorf("self cycle detected")
}
if parent.isAncestor(s) {
return fmt.Errorf("cycle detected")
}
if s.isAncestor(parent) {
return fmt.Errorf("parent already chained")
}
s.parents = append(s.parents, parent)
return nil
}

0 comments on commit df8c55d

Please sign in to comment.