Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions pkg/sql/schemachanger/rel/query_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,19 +427,41 @@ func (p *queryBuilder) setSubQueryDepths(entitySlots []slotIdx) {
}

func (p *queryBuilder) setSubqueryDepth(s *subQuery, entitySlots []slotIdx) {
var max int
// First, figure out which entity binds each variable.
// slotToEntityProvider maps slot -> entity slot that provides it.
slotToEntityProvider := make(map[int]int)
for _, f := range p.facts {
if p.slotIsEntity[f.variable] {
// This fact is about an entity, so it might bind a variable
slotToEntityProvider[int(f.value)] = int(f.variable)
}
}

var maxEntitySlot int
s.inputSlotMappings.ForEach(func(key, _ int) {
if p.slotIsEntity[key] && key > max {
max = key
if p.slotIsEntity[key] {
if key > maxEntitySlot {
maxEntitySlot = key
}
} else {
// For non-entity variables, find which entity provides them
provider, hasProvider := slotToEntityProvider[key]
if hasProvider {
if provider > maxEntitySlot {
maxEntitySlot = provider
}
}
}
})

got := sort.Search(len(entitySlots), func(i int) bool {
return int(entitySlots[i]) >= max
return int(entitySlots[i]) >= maxEntitySlot
})
if got == len(entitySlots) {
panic(errors.AssertionFailedf("failed to find maximum entity in entitySlots: %v not in %v",
max, entitySlots))
maxEntitySlot, entitySlots))
}

s.depth = queryDepth(got + 1)
}

Expand Down
1 change: 1 addition & 0 deletions pkg/sql/schemachanger/rel/query_eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ func (ec *evalContext) maybeVisitSubqueries() (nextSubQuery int, done bool, erro
func (ec *evalContext) visitSubquery(query int) (done bool, _ error) {
sub := ec.q.notJoins[query]
sec := sub.query.getEvalContext()

defer sub.query.putEvalContext(sec)
defer func() { // reset the slots populated to run the subquery
sub.inputSlotMappings.ForEach(func(_, subSlot int) {
Expand Down
185 changes: 185 additions & 0 deletions pkg/sql/schemachanger/rel/rel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -577,3 +577,188 @@ func TestConcurrentQueryInDifferentDatabases(t *testing.T) {
}
require.NoError(t, g.Wait())
}

type notJoinTestAttr string

func (a notJoinTestAttr) String() string { return string(a) }

// TestNotJoinSubqueryDepthWithNonEntityVariables tests that notJoin subqueries
// execute at the correct depth when they depend on non-entity variables that
// are bound by entities at different join depths.
//
// This test catches a bug where notJoin subqueries would execute too early,
// before their required non-entity variables were bound. The fix ensures that
// we track which entity provides each variable through facts, so notJoin
// subqueries wait until all their input variables are available.
func TestNotJoinSubqueryDepthWithNonEntityVariables(t *testing.T) {
defer leaktest.AfterTest(t)()

type FirstEntity struct {
ID int
SharedID int
}
type SecondEntity struct {
ID int
SharedID int
Value string
}
type ThirdEntity struct {
ID int
SharedID int
Flag int
}

schema := rel.MustSchema("test_notjoin_depth",
rel.EntityMapping(reflect.TypeOf((*FirstEntity)(nil)),
rel.EntityAttr(notJoinTestAttr("id"), "ID"),
rel.EntityAttr(notJoinTestAttr("shared_id"), "SharedID"),
),
rel.EntityMapping(reflect.TypeOf((*SecondEntity)(nil)),
rel.EntityAttr(notJoinTestAttr("id"), "ID"),
rel.EntityAttr(notJoinTestAttr("shared_id"), "SharedID"),
rel.EntityAttr(notJoinTestAttr("value"), "Value"),
),
rel.EntityMapping(reflect.TypeOf((*ThirdEntity)(nil)),
rel.EntityAttr(notJoinTestAttr("id"), "ID"),
rel.EntityAttr(notJoinTestAttr("shared_id"), "SharedID"),
rel.EntityAttr(notJoinTestAttr("flag"), "Flag"),
),
)

// Define a notJoin rule that depends on a non-entity variable (shared_id).
// This rule checks if there's no ThirdEntity with the given shared_id and flag=1.
noThirdWithFlag := schema.DefNotJoin1("no_third_with_flag", "shared_id_var", func(
sharedIDVar rel.Var,
) rel.Clauses {
return rel.Clauses{
rel.Var("third").Type((*ThirdEntity)(nil)),
rel.Var("third").AttrEqVar(notJoinTestAttr("shared_id"), sharedIDVar),
rel.Var("third").AttrEq(notJoinTestAttr("flag"), 1),
}
})

first1 := &FirstEntity{ID: 1, SharedID: 100}
second1 := &SecondEntity{ID: 2, SharedID: 100, Value: "test"}
third1 := &ThirdEntity{ID: 3, SharedID: 100, Flag: 0}
third2 := &ThirdEntity{ID: 4, SharedID: 200, Flag: 1}

db, err := rel.NewDatabase(schema,
rel.Index{Attrs: []rel.Attr{rel.Type}},
rel.Index{Attrs: []rel.Attr{rel.Self}},
rel.Index{Attrs: []rel.Attr{notJoinTestAttr("id")}},
rel.Index{Attrs: []rel.Attr{notJoinTestAttr("shared_id")}},
rel.Index{Attrs: []rel.Attr{notJoinTestAttr("flag")}},
rel.Index{Attrs: []rel.Attr{notJoinTestAttr("value")}},
)
require.NoError(t, err)
require.NoError(t, db.Insert(first1))
require.NoError(t, db.Insert(second1))
require.NoError(t, db.Insert(third1))
require.NoError(t, db.Insert(third2))

// Test case 1: Query where shared_id is bound by SecondEntity (at depth 2).
// The notJoin should execute after SecondEntity is joined.
t.Run("notjoin_executes_after_second_entity", func(t *testing.T) {
q, err := rel.NewQuery(schema,
// FirstEntity is joined first (depth 1).
rel.Var("first").Type((*FirstEntity)(nil)),
rel.Var("first").AttrEq(notJoinTestAttr("id"), 1),
// SecondEntity is joined second (depth 2) and binds shared_id_var.
rel.Var("second").Type((*SecondEntity)(nil)),
rel.Var("second").AttrEqVar(notJoinTestAttr("shared_id"), "shared_id_var"),
// Join FirstEntity and SecondEntity on shared_id.
rel.Var("first").AttrEqVar(notJoinTestAttr("shared_id"), "shared_id_var"),
// This notJoin depends on shared_id_var, which is bound by SecondEntity
// It should execute at depth 2 or later, not before
noThirdWithFlag("shared_id_var"),
)
require.NoError(t, err)

var results [][]interface{}
err = q.Iterate(db, nil, func(r rel.Result) error {
results = append(results, []interface{}{
r.Var(rel.Var("first")),
r.Var(rel.Var("second")),
r.Var(rel.Var("shared_id_var")),
})
return nil
})
require.NoError(t, err)
// Should find the combination where shared_id=100.
require.Len(t, results, 1)
require.Equal(t, first1, results[0][0])
require.Equal(t, second1, results[0][1])
require.Equal(t, 100, results[0][2])
})

// Test case 2: Query where shared_id would cause the notJoin to fail.
t.Run("notjoin_filters_results_correctly", func(t *testing.T) {
// Add a FirstEntity with shared_id=200.
first2 := &FirstEntity{ID: 5, SharedID: 200}
second2 := &SecondEntity{ID: 6, SharedID: 200, Value: "test2"}
require.NoError(t, db.Insert(first2))
require.NoError(t, db.Insert(second2))

q, err := rel.NewQuery(schema,
rel.Var("first").Type((*FirstEntity)(nil)),
rel.Var("first").AttrEq(notJoinTestAttr("id"), 5),
rel.Var("second").Type((*SecondEntity)(nil)),
rel.Var("second").AttrEqVar(notJoinTestAttr("shared_id"), "shared_id_var"),
rel.Var("first").AttrEqVar(notJoinTestAttr("shared_id"), "shared_id_var"),
noThirdWithFlag("shared_id_var"),
)
require.NoError(t, err)

var results [][]interface{}
err = q.Iterate(db, nil, func(r rel.Result) error {
results = append(results, []interface{}{
r.Var(rel.Var("first")),
r.Var(rel.Var("second")),
})
return nil
})
require.NoError(t, err)
// Should find no results because third2 has shared_id=200 and flag=1.
require.Empty(t, results)
})

// Test case 3: Complex case with non-entity variable bound at depth 3.
t.Run("notjoin_with_variable_bound_at_depth_3", func(t *testing.T) {
// Define a more complex notJoin that uses two variables.
complexNotJoin := schema.DefNotJoin2("complex_not_join", "sid", "val", func(
sidVar, valVar rel.Var,
) rel.Clauses {
return rel.Clauses{
rel.Var("e").Type((*SecondEntity)(nil)),
rel.Var("e").AttrEqVar(notJoinTestAttr("shared_id"), sidVar),
rel.Var("e").AttrEqVar(notJoinTestAttr("value"), valVar),
}
})

q, err := rel.NewQuery(schema,
// Depth 1
rel.Var("f1").Type((*FirstEntity)(nil)),
// Depth 2
rel.Var("f2").Type((*FirstEntity)(nil)),
rel.Var("f2").AttrNeq(notJoinTestAttr("id"), 1),
// Depth 3 - this is where shared_id and value are bound.
rel.Var("s").Type((*SecondEntity)(nil)),
rel.Var("s").AttrEqVar(notJoinTestAttr("shared_id"), "sid"),
rel.Var("s").AttrEqVar(notJoinTestAttr("value"), "val"),
// The notJoin should only execute after depth 3.
complexNotJoin("sid", "val"),
// Add a filter to limit results.
rel.Filter("limit", "f1")(func(e *FirstEntity) bool {
return e.ID == 1
}),
)
require.NoError(t, err)

// This should execute without "unbound variable" errors
// even though the notJoin depends on variables bound at depth 3
err = q.Iterate(db, nil, func(r rel.Result) error {
return nil
})
require.NoError(t, err)
})
}
Loading