diff --git a/pkg/sql/schemachanger/rel/query_build.go b/pkg/sql/schemachanger/rel/query_build.go index d82eed89db16..f3e5a959b8b4 100644 --- a/pkg/sql/schemachanger/rel/query_build.go +++ b/pkg/sql/schemachanger/rel/query_build.go @@ -427,19 +427,41 @@ func (p *queryBuilder) setSubQueryDepths(entitySlots []slotIdx) { } func (p *queryBuilder) setSubqueryDepth(s *subQuery, entitySlots []slotIdx) { - var max int + // First, figure out which entity binds each variable. + // slotToEntityProvider maps a fact's value slot to the entity slot of that fact; a later fact in p.facts for the same value slot overwrites an earlier one. + slotToEntityProvider := make(map[int]int) + for _, f := range p.facts { + if p.slotIsEntity[f.variable] { + // This fact is about an entity, so it might bind a variable. + slotToEntityProvider[int(f.value)] = int(f.variable) + } + } + + var maxEntitySlot int s.inputSlotMappings.ForEach(func(key, _ int) { - if p.slotIsEntity[key] && key > max { - max = key + if p.slotIsEntity[key] { + if key > maxEntitySlot { + maxEntitySlot = key + } + } else { + // For non-entity variables, find which entity provides them. + provider, hasProvider := slotToEntityProvider[key] + if hasProvider { + if provider > maxEntitySlot { + maxEntitySlot = provider + } + } } }) + got := sort.Search(len(entitySlots), func(i int) bool { - return int(entitySlots[i]) >= max + return int(entitySlots[i]) >= maxEntitySlot }) if got == len(entitySlots) { panic(errors.AssertionFailedf("failed to find maximum entity in entitySlots: %v not in %v", - max, entitySlots)) + maxEntitySlot, entitySlots)) } + s.depth = queryDepth(got + 1) } diff --git a/pkg/sql/schemachanger/rel/query_eval.go b/pkg/sql/schemachanger/rel/query_eval.go index 21382a9f47d7..ab882a2c04e2 100644 --- a/pkg/sql/schemachanger/rel/query_eval.go +++ b/pkg/sql/schemachanger/rel/query_eval.go @@ -450,6 +450,7 @@ func (ec *evalContext) maybeVisitSubqueries() (nextSubQuery int, done bool, erro func (ec *evalContext) visitSubquery(query int) (done bool, _ error) { sub := ec.q.notJoins[query] sec := 
sub.query.getEvalContext() + defer sub.query.putEvalContext(sec) defer func() { // reset the slots populated to run the subquery sub.inputSlotMappings.ForEach(func(_, subSlot int) { diff --git a/pkg/sql/schemachanger/rel/rel_test.go b/pkg/sql/schemachanger/rel/rel_test.go index fc8e5ca97d1c..42a2cbbfb047 100644 --- a/pkg/sql/schemachanger/rel/rel_test.go +++ b/pkg/sql/schemachanger/rel/rel_test.go @@ -577,3 +577,188 @@ func TestConcurrentQueryInDifferentDatabases(t *testing.T) { } require.NoError(t, g.Wait()) } + +type notJoinTestAttr string + +func (a notJoinTestAttr) String() string { return string(a) } + +// TestNotJoinSubqueryDepthWithNonEntityVariables tests that notJoin subqueries +// execute at the correct depth when they depend on non-entity variables that +// are bound by entities at different join depths. +// +// This test catches a bug where notJoin subqueries would execute too early, +// before their required non-entity variables were bound. The fix ensures that +// we track which entity provides each variable through facts, so notJoin +// subqueries wait until all their input variables are available. 
+func TestNotJoinSubqueryDepthWithNonEntityVariables(t *testing.T) { + defer leaktest.AfterTest(t)() + + type FirstEntity struct { + ID int + SharedID int + } + type SecondEntity struct { + ID int + SharedID int + Value string + } + type ThirdEntity struct { + ID int + SharedID int + Flag int + } + + schema := rel.MustSchema("test_notjoin_depth", + rel.EntityMapping(reflect.TypeOf((*FirstEntity)(nil)), + rel.EntityAttr(notJoinTestAttr("id"), "ID"), + rel.EntityAttr(notJoinTestAttr("shared_id"), "SharedID"), + ), + rel.EntityMapping(reflect.TypeOf((*SecondEntity)(nil)), + rel.EntityAttr(notJoinTestAttr("id"), "ID"), + rel.EntityAttr(notJoinTestAttr("shared_id"), "SharedID"), + rel.EntityAttr(notJoinTestAttr("value"), "Value"), + ), + rel.EntityMapping(reflect.TypeOf((*ThirdEntity)(nil)), + rel.EntityAttr(notJoinTestAttr("id"), "ID"), + rel.EntityAttr(notJoinTestAttr("shared_id"), "SharedID"), + rel.EntityAttr(notJoinTestAttr("flag"), "Flag"), + ), + ) + + // Define a notJoin rule that depends on a non-entity variable (shared_id). + // This rule checks if there's no ThirdEntity with the given shared_id and flag=1. 
+ noThirdWithFlag := schema.DefNotJoin1("no_third_with_flag", "shared_id_var", func( + sharedIDVar rel.Var, + ) rel.Clauses { + return rel.Clauses{ + rel.Var("third").Type((*ThirdEntity)(nil)), + rel.Var("third").AttrEqVar(notJoinTestAttr("shared_id"), sharedIDVar), + rel.Var("third").AttrEq(notJoinTestAttr("flag"), 1), + } + }) + + first1 := &FirstEntity{ID: 1, SharedID: 100} + second1 := &SecondEntity{ID: 2, SharedID: 100, Value: "test"} + third1 := &ThirdEntity{ID: 3, SharedID: 100, Flag: 0} + third2 := &ThirdEntity{ID: 4, SharedID: 200, Flag: 1} + + db, err := rel.NewDatabase(schema, + rel.Index{Attrs: []rel.Attr{rel.Type}}, + rel.Index{Attrs: []rel.Attr{rel.Self}}, + rel.Index{Attrs: []rel.Attr{notJoinTestAttr("id")}}, + rel.Index{Attrs: []rel.Attr{notJoinTestAttr("shared_id")}}, + rel.Index{Attrs: []rel.Attr{notJoinTestAttr("flag")}}, + rel.Index{Attrs: []rel.Attr{notJoinTestAttr("value")}}, + ) + require.NoError(t, err) + require.NoError(t, db.Insert(first1)) + require.NoError(t, db.Insert(second1)) + require.NoError(t, db.Insert(third1)) + require.NoError(t, db.Insert(third2)) + + // Test case 1: Query where shared_id is bound by SecondEntity (at depth 2). + // The notJoin should execute after SecondEntity is joined. + t.Run("notjoin_executes_after_second_entity", func(t *testing.T) { + q, err := rel.NewQuery(schema, + // FirstEntity is joined first (depth 1). + rel.Var("first").Type((*FirstEntity)(nil)), + rel.Var("first").AttrEq(notJoinTestAttr("id"), 1), + // SecondEntity is joined second (depth 2) and binds shared_id_var. + rel.Var("second").Type((*SecondEntity)(nil)), + rel.Var("second").AttrEqVar(notJoinTestAttr("shared_id"), "shared_id_var"), + // Join FirstEntity and SecondEntity on shared_id. 
+ rel.Var("first").AttrEqVar(notJoinTestAttr("shared_id"), "shared_id_var"), + // This notJoin depends on shared_id_var, which is bound by SecondEntity. + // It should execute at depth 2 or later, not before. + noThirdWithFlag("shared_id_var"), + ) + require.NoError(t, err) + + var results [][]interface{} + err = q.Iterate(db, nil, func(r rel.Result) error { + results = append(results, []interface{}{ + r.Var(rel.Var("first")), + r.Var(rel.Var("second")), + r.Var(rel.Var("shared_id_var")), + }) + return nil + }) + require.NoError(t, err) + // Should find the combination where shared_id=100. + require.Len(t, results, 1) + require.Equal(t, first1, results[0][0]) + require.Equal(t, second1, results[0][1]) + require.Equal(t, 100, results[0][2]) + }) + + // Test case 2: Query where shared_id would cause the notJoin to fail. + t.Run("notjoin_filters_results_correctly", func(t *testing.T) { + // Add a FirstEntity and a SecondEntity with shared_id=200 to the shared db. + first2 := &FirstEntity{ID: 5, SharedID: 200} + second2 := &SecondEntity{ID: 6, SharedID: 200, Value: "test2"} + require.NoError(t, db.Insert(first2)) + require.NoError(t, db.Insert(second2)) + + q, err := rel.NewQuery(schema, + rel.Var("first").Type((*FirstEntity)(nil)), + rel.Var("first").AttrEq(notJoinTestAttr("id"), 5), + rel.Var("second").Type((*SecondEntity)(nil)), + rel.Var("second").AttrEqVar(notJoinTestAttr("shared_id"), "shared_id_var"), + rel.Var("first").AttrEqVar(notJoinTestAttr("shared_id"), "shared_id_var"), + noThirdWithFlag("shared_id_var"), + ) + require.NoError(t, err) + + var results [][]interface{} + err = q.Iterate(db, nil, func(r rel.Result) error { + results = append(results, []interface{}{ + r.Var(rel.Var("first")), + r.Var(rel.Var("second")), + }) + return nil + }) + require.NoError(t, err) + // Should find no results because third2 has shared_id=200 and flag=1. + require.Empty(t, results) + }) + + // Test case 3: Complex case with non-entity variable bound at depth 3. 
+ t.Run("notjoin_with_variable_bound_at_depth_3", func(t *testing.T) { + // Define a more complex notJoin that uses two variables. + complexNotJoin := schema.DefNotJoin2("complex_not_join", "sid", "val", func( + sidVar, valVar rel.Var, + ) rel.Clauses { + return rel.Clauses{ + rel.Var("e").Type((*SecondEntity)(nil)), + rel.Var("e").AttrEqVar(notJoinTestAttr("shared_id"), sidVar), + rel.Var("e").AttrEqVar(notJoinTestAttr("value"), valVar), + } + }) + + q, err := rel.NewQuery(schema, + // Depth 1 + rel.Var("f1").Type((*FirstEntity)(nil)), + // Depth 2 + rel.Var("f2").Type((*FirstEntity)(nil)), + rel.Var("f2").AttrNeq(notJoinTestAttr("id"), 1), + // Depth 3 - this is where shared_id and value are bound. + rel.Var("s").Type((*SecondEntity)(nil)), + rel.Var("s").AttrEqVar(notJoinTestAttr("shared_id"), "sid"), + rel.Var("s").AttrEqVar(notJoinTestAttr("value"), "val"), + // The notJoin should only execute after depth 3. + complexNotJoin("sid", "val"), + // Add a filter to limit results. + rel.Filter("limit", "f1")(func(e *FirstEntity) bool { + return e.ID == 1 + }), + ) + require.NoError(t, err) + + // This should execute without "unbound variable" errors + // even though the notJoin depends on variables bound at depth 3. + err = q.Iterate(db, nil, func(r rel.Result) error { + return nil + }) + require.NoError(t, err) + }) +}