diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index 5724011cfb..efea383839 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -106,7 +106,8 @@ func moveJoinConditionsToFilter(ctx *sql.Context, a *Analyzer, n sql.Node, scope } if filtersMoved == 0 { - return n, nil + topJoin = n + return topJoin, nil } if len(condFilters) > 0 { diff --git a/sql/analyzer/optimization_rules_test.go b/sql/analyzer/optimization_rules_test.go index 5b92bfe8ae..56df6a2ede 100644 --- a/sql/analyzer/optimization_rules_test.go +++ b/sql/analyzer/optimization_rules_test.go @@ -255,6 +255,45 @@ func TestMoveJoinConditionsToFilter(t *testing.T) { ) assertNodesEqualWithDiff(t, expected, result) + + node = plan.NewInnerJoin( + plan.NewResolvedTable(t1, nil, nil), + plan.NewInnerJoin( + plan.NewResolvedTable(t2, nil, nil), + plan.NewResolvedTable(t3, nil, nil), + expression.JoinAnd( + eq(col(0, "t2", "c"), col(0, "t3", "e")), + eq(col(0, "t3", "a"), lit(5)), + ), + ), + expression.JoinAnd( + eq(col(0, "t1", "c"), col(0, "t2", "e")), + ), + ) + + result, err = rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node, nil) + require.NoError(err) + + expected = plan.NewFilter( + expression.JoinAnd( + eq(col(0, "t3", "a"), lit(5)), + ), + plan.NewInnerJoin( + plan.NewResolvedTable(t1, nil, nil), + plan.NewInnerJoin( + plan.NewResolvedTable(t2, nil, nil), + plan.NewResolvedTable(t3, nil, nil), + expression.JoinAnd( + eq(col(0, "t2", "c"), col(0, "t3", "e")), + ), + ), + expression.JoinAnd( + eq(col(0, "t1", "c"), col(0, "t2", "e")), + ), + ), + ) + + assertNodesEqualWithDiff(t, expected, result) } func TestEvalFilter(t *testing.T) {