diff --git a/alpha/model/model.go b/alpha/model/model.go index c20ad4cd4..e627df8c3 100644 --- a/alpha/model/model.go +++ b/alpha/model/model.go @@ -11,6 +11,7 @@ import ( "github.com/h2non/filetype/matchers" "github.com/h2non/filetype/types" svg "github.com/h2non/go-is-svg" + "k8s.io/apimachinery/pkg/util/sets" "github.com/operator-framework/operator-registry/alpha/property" ) @@ -184,7 +185,7 @@ func (c *Channel) Validate() error { } if len(c.Bundles) > 0 { - if _, err := c.Head(); err != nil { + if err := c.validateReplacesChain(); err != nil { result.subErrors = append(result.subErrors, err) } } @@ -203,6 +204,51 @@ func (c *Channel) Validate() error { return result.orNil() } +// validateReplacesChain checks the replaces chain of a channel. +// Specifically the following rules must be followed: +// 1. There must be exactly 1 channel head. +// 2. Beginning at the head, the replaces chain must reach all non-skipped entries. +// Non-skipped entries are defined as entries that are not skipped by any other entry in the channel. +// 3. There must be no cycles in the replaces chain. +// 4. The tail entry in the replaces chain is permitted to replace a non-existent entry. +func (c *Channel) validateReplacesChain() error { + head, err := c.Head() + if err != nil { + return err + } + + allBundles := sets.NewString() + skippedBundles := sets.NewString() + for _, b := range c.Bundles { + allBundles = allBundles.Insert(b.Name) + skippedBundles = skippedBundles.Insert(b.Skips...) + } + + chainFrom := map[string][]string{} + replacesChainFromHead := sets.NewString(head.Name) + cur := head + for cur != nil { + if _, ok := chainFrom[cur.Name]; !ok { + chainFrom[cur.Name] = []string{cur.Name} + } + for k := range chainFrom { + chainFrom[k] = append(chainFrom[k], cur.Replaces) + } + if replacesChainFromHead.Has(cur.Replaces) { + return fmt.Errorf("detected cycle in replaces chain of upgrade graph: %s", strings.Join(chainFrom[cur.Replaces], " -> ")) + } + replacesChainFromHead = replacesChainFromHead.Insert(cur.Replaces) + cur = c.Bundles[cur.Replaces] + } + + strandedBundles := allBundles.Difference(replacesChainFromHead).Difference(skippedBundles).List() + if len(strandedBundles) > 0 { + return fmt.Errorf("channel contains one or more stranded bundles: %s", strings.Join(strandedBundles, ", ")) + } + + return nil +} + type Bundle struct { Package *Package Channel *Channel diff --git a/alpha/model/model_test.go b/alpha/model/model_test.go index effa889d0..463ee6173 100644 --- a/alpha/model/model_test.go +++ b/alpha/model/model_test.go @@ -118,6 +118,59 @@ func TestChannelHead(t *testing.T) { } } +func TestValidReplacesChain(t *testing.T) { + type spec struct { + name string + ch Channel + assertion require.ErrorAssertionFunc + } + specs := []spec{ + { + name: "Success/Valid", + ch: Channel{Bundles: map[string]*Bundle{ + "anakin.v0.0.1": {Name: "anakin.v0.0.1"}, + "anakin.v0.0.2": {Name: "anakin.v0.0.2", Skips: []string{"anakin.v0.0.1"}}, + "anakin.v0.0.3": {Name: "anakin.v0.0.3", Skips: []string{"anakin.v0.0.2"}}, + "anakin.v0.0.4": {Name: "anakin.v0.0.4", Replaces: "anakin.v0.0.3"}, + }}, + assertion: require.NoError, + }, + { + name: "Error/CycleNoHops", + ch: Channel{Bundles: map[string]*Bundle{ + "anakin.v0.0.4": {Name: "anakin.v0.0.4", Replaces: "anakin.v0.0.4"}, + "anakin.v0.0.5": {Name: "anakin.v0.0.5", Replaces: "anakin.v0.0.4"}, + }}, + assertion: hasError(`detected cycle in replaces chain of upgrade graph: anakin.v0.0.4 -> anakin.v0.0.4`), + }, + { + name: "Error/CycleMultipleHops", + ch: Channel{Bundles: map[string]*Bundle{ + "anakin.v0.0.1": {Name: "anakin.v0.0.1", Replaces: "anakin.v0.0.3"}, + "anakin.v0.0.2": {Name: "anakin.v0.0.2", Replaces: "anakin.v0.0.1"}, + "anakin.v0.0.3": {Name: "anakin.v0.0.3", Replaces: "anakin.v0.0.2"}, + "anakin.v0.0.4": {Name: "anakin.v0.0.4", Replaces: "anakin.v0.0.3"}, + }}, + assertion: hasError(`detected cycle in replaces chain of upgrade graph: anakin.v0.0.3 -> anakin.v0.0.2 -> anakin.v0.0.1 -> anakin.v0.0.3`), + }, + { + name: "Error/Stranded", + ch: Channel{Bundles: map[string]*Bundle{ + "anakin.v0.0.1": {Name: "anakin.v0.0.1"}, + "anakin.v0.0.2": {Name: "anakin.v0.0.2", Replaces: "anakin.v0.0.1"}, + "anakin.v0.0.3": {Name: "anakin.v0.0.3", Skips: []string{"anakin.v0.0.2"}}, + }}, + assertion: hasError(`channel contains one or more stranded bundles: anakin.v0.0.1`), + }, + } + for _, s := range specs { + t.Run(s.name, func(t *testing.T) { + err := s.ch.validateReplacesChain() + s.assertion(t, err) + }) + } +} + func hasError(expectedError string) require.ErrorAssertionFunc { return func(t require.TestingT, actualError error, args ...interface{}) { if stdt, ok := t.(*testing.T); ok {