Skip to content

Commit

Permalink
test: add a test for cyclic merges
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Chadwell <me@jedevc.com>
  • Loading branch information
jedevc committed Jan 16, 2024
1 parent 6b5891d commit 9290b84
Showing 1 changed file with 96 additions and 0 deletions.
96 changes: 96 additions & 0 deletions solver/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3090,6 +3090,57 @@ func TestMergedEdgesLookup(t *testing.T) {
}
}

func TestMergedEdgesCycle(t *testing.T) {
t.Parallel()

ctx := context.TODO()

cacheManager := newTrackingCacheManager(NewInMemoryCacheManager())

l := NewSolver(SolverOpt{
ResolveOpFunc: testOpResolver,
DefaultCache: cacheManager,
})
defer l.Close()

j0, err := l.NewJob("j0")
require.NoError(t, err)

defer func() {
if j0 != nil {
j0.Discard()
}
}()

// 2 different vertices, va and vb, both with the same cache key
va := vtxSumMult(2, vtxOpt{inputs: []Edge{
{Vertex: vtxConst(3, vtxOpt{})},
}})
vb := vtxSumMult(2, vtxOpt{inputs: []Edge{
{Vertex: vtxConst(3, vtxOpt{})},
}})

// 4 edges va[0], va[1], vb[0], vb[1]
// by ordering them like this, we try and trigger merge va[0]->vb[0] and
// vb[1]->va[1] to cause a cycle
g := Edge{
Vertex: vtxSumMult(1, vtxOpt{inputs: []Edge{
{Vertex: va, Index: 1}, // 2 * 3 = 6
{Vertex: vb, Index: 0}, // 2 + 3 = 5
{Vertex: va, Index: 0}, // 2 + 3 = 5
{Vertex: vb, Index: 1}, // 2 * 3 = 6
}}), // 1 + 6 + 5 + 5 + 6 = 23
}
g.Vertex.(*vertexSumMult).setupCallCounters()

res, err := j0.Build(ctx, g)
require.NoError(t, err)
require.Equal(t, 23, unwrapInt(res))

require.NoError(t, j0.Discard())
j0 = nil
}

func TestCacheLoadError(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -3432,6 +3483,8 @@ func (v *vertex) setCallCounters(cacheCount, execCount *int64) {
v = vv
case *vertexSum:
v = vv.vertex
case *vertexSumMult:
v = vv.vertex
case *vertexConst:
v = vv.vertex
case *vertexSubBuild:
Expand Down Expand Up @@ -3599,6 +3652,49 @@ func (v *vertexSum) Acquire(ctx context.Context) (ReleaseFunc, error) {
return func() {}, nil
}

func vtxSumMult(v int, opt vtxOpt) *vertexSumMult {
if opt.cacheKeySeed == "" {
opt.cacheKeySeed = fmt.Sprintf("summult-%d-%d", v, len(opt.inputs))
}
if opt.name == "" {
opt.name = opt.cacheKeySeed + "-" + identity.NewID()
}
return &vertexSumMult{vertex: vtx(opt), value: v}
}

type vertexSumMult struct {
*vertex
value int
}

func (v *vertexSumMult) Sys() interface{} {
return v
}

func (v *vertexSumMult) Exec(ctx context.Context, g session.Group, inputs []Result) (outputs []Result, err error) {
if err := v.exec(ctx, inputs); err != nil {
return nil, err
}
s := v.value
p := v.value
for _, inp := range inputs {
r, ok := inp.Sys().(*dummyResult)
if !ok {
return nil, errors.Errorf("invalid input type: %T", inp.Sys())
}
s += r.intValue
p *= r.intValue
}
return []Result{
&dummyResult{id: identity.NewID(), intValue: s},
&dummyResult{id: identity.NewID(), intValue: p},
}, nil
}

func (v *vertexSumMult) Acquire(ctx context.Context) (ReleaseFunc, error) {
return func() {}, nil
}

func vtxSubBuild(g Edge, opt vtxOpt) *vertexSubBuild {
if opt.cacheKeySeed == "" {
opt.cacheKeySeed = fmt.Sprintf("sum-%s", identity.NewID())
Expand Down

0 comments on commit 9290b84

Please sign in to comment.