diff --git a/solver/scheduler_test.go b/solver/scheduler_test.go index 79497913076d6..c30c8c60a41b4 100644 --- a/solver/scheduler_test.go +++ b/solver/scheduler_test.go @@ -3090,6 +3090,57 @@ func TestMergedEdgesLookup(t *testing.T) { } } +func TestMergedEdgesCycle(t *testing.T) { + t.Parallel() + + ctx := context.TODO() + + cacheManager := newTrackingCacheManager(NewInMemoryCacheManager()) + + l := NewSolver(SolverOpt{ + ResolveOpFunc: testOpResolver, + DefaultCache: cacheManager, + }) + defer l.Close() + + j0, err := l.NewJob("j0") + require.NoError(t, err) + + defer func() { + if j0 != nil { + j0.Discard() + } + }() + + // 2 different vertices, va and vb, both with the same cache key + va := vtxSumMult(2, vtxOpt{inputs: []Edge{ + {Vertex: vtxConst(3, vtxOpt{})}, + }}) + vb := vtxSumMult(2, vtxOpt{inputs: []Edge{ + {Vertex: vtxConst(3, vtxOpt{})}, + }}) + + // 4 edges va[0], va[1], vb[0], vb[1] + // by ordering them like this, we try and trigger merge va[0]->vb[0] and + // vb[1]->va[1] to cause a cycle + g := Edge{ + Vertex: vtxSumMult(1, vtxOpt{inputs: []Edge{ + {Vertex: va, Index: 1}, // 2 * 3 = 6 + {Vertex: vb, Index: 0}, // 2 + 3 = 5 + {Vertex: va, Index: 0}, // 2 + 3 = 5 + {Vertex: vb, Index: 1}, // 2 * 3 = 6 + }}), // 1 + 6 + 5 + 5 + 6 = 23 + } + g.Vertex.(*vertexSumMult).setupCallCounters() + + res, err := j0.Build(ctx, g) + require.NoError(t, err) + require.Equal(t, 23, unwrapInt(res)) + + require.NoError(t, j0.Discard()) + j0 = nil +} + func TestCacheLoadError(t *testing.T) { t.Parallel() @@ -3432,6 +3483,8 @@ func (v *vertex) setCallCounters(cacheCount, execCount *int64) { v = vv case *vertexSum: v = vv.vertex + case *vertexSumMult: + v = vv.vertex case *vertexConst: v = vv.vertex case *vertexSubBuild: @@ -3599,6 +3652,49 @@ func (v *vertexSum) Acquire(ctx context.Context) (ReleaseFunc, error) { return func() {}, nil } +func vtxSumMult(v int, opt vtxOpt) *vertexSumMult { + if opt.cacheKeySeed == "" { + opt.cacheKeySeed = fmt.Sprintf("summult-%d-%d", v, len(opt.inputs)) + } + if opt.name == "" { + opt.name = opt.cacheKeySeed + "-" + identity.NewID() + } + return &vertexSumMult{vertex: vtx(opt), value: v} +} + +type vertexSumMult struct { + *vertex + value int +} + +func (v *vertexSumMult) Sys() interface{} { + return v +} + +func (v *vertexSumMult) Exec(ctx context.Context, g session.Group, inputs []Result) (outputs []Result, err error) { + if err := v.exec(ctx, inputs); err != nil { + return nil, err + } + s := v.value + p := v.value + for _, inp := range inputs { + r, ok := inp.Sys().(*dummyResult) + if !ok { + return nil, errors.Errorf("invalid input type: %T", inp.Sys()) + } + s += r.intValue + p *= r.intValue + } + return []Result{ + &dummyResult{id: identity.NewID(), intValue: s}, + &dummyResult{id: identity.NewID(), intValue: p}, + }, nil +} + +func (v *vertexSumMult) Acquire(ctx context.Context) (ReleaseFunc, error) { + return func() {}, nil +} + func vtxSubBuild(g Edge, opt vtxOpt) *vertexSubBuild { if opt.cacheKeySeed == "" { opt.cacheKeySeed = fmt.Sprintf("sum-%s", identity.NewID())