Enable predicate pushdown for categorical dimension filters #1227

Merged
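For orientation: "predicate pushdown" here means relocating a query's WHERE filter from after a dimension join down to the source node that supplies the filtered column, so joins read fewer rows. A minimal toy sketch of the rewrite this PR teaches the dataflow plan builder to make (toy classes, not MetricFlow's node API):

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Node:
    """Toy stand-in for a dataflow plan node (not MetricFlow's API)."""

    name: str
    parent: Optional["Node"] = None
    predicate: Optional[str] = None

    def describe(self) -> str:
        label = self.name + (f"[{self.predicate}]" if self.predicate else "")
        return label if self.parent is None else f"{self.parent.describe()} -> {label}"


source = Node("ReadListings")

# Before: the filter runs after the join, so the join processes every row.
late = Node("WhereConstraint", parent=Node("JoinOnEntities", parent=source), predicate="country = 'US'")
print(late.describe())   # ReadListings -> JoinOnEntities -> WhereConstraint[country = 'US']

# After pushdown: the same predicate is applied at the source, pre-join,
# so the join only sees rows that survive the filter.
early = Node("JoinOnEntities", parent=Node("WhereConstraint", parent=source, predicate="country = 'US'"))
print(early.describe())  # ReadListings -> WhereConstraint[country = 'US'] -> JoinOnEntities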
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240521-202252.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Enable predicate pushdown for categorical dimensions
time: 2024-05-21T20:22:52.841802-07:00
custom:
Author: tlento
Issue: "1011"
@@ -263,7 +263,8 @@ metric:
name: instant_booking_fraction_of_max_value
description: |
Average instant booking value as a ratio of overall max booking value.
Tests constrained ratio measure.
Tests constrained ratio measure and predicate pushdown with different filters
on the same measure input.
type: ratio
type_params:
numerator:
@@ -331,7 +332,8 @@ metric:
name: regional_starting_balance_ratios
description: |
First day account balance ratio of western vs eastern region starting balance ratios,
used to test interaction between semi-additive measures and measure constraints
used to test interaction between semi-additive measures and measure constraints, and
behavior of predicate pushdown when there are multiple filters on the same categorical dimension
type: ratio
type_params:
numerator:
26 changes: 23 additions & 3 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -184,7 +184,9 @@ def _build_query_output_node(
)
)

predicate_pushdown_state = PredicatePushdownState(time_range_constraint=query_spec.time_range_constraint)
predicate_pushdown_state = PredicatePushdownState(
time_range_constraint=query_spec.time_range_constraint, where_filter_specs=query_level_filter_specs
)

return self._build_metrics_output_node(
metric_specs=tuple(
@@ -251,6 +253,7 @@ def _build_aggregated_conversion_node(
disabled_pushdown_state = PredicatePushdownState.with_pushdown_disabled()
time_range_only_pushdown_state = PredicatePushdownState(
time_range_constraint=predicate_pushdown_state.time_range_constraint,
where_filter_specs=tuple(),
pushdown_enabled_types=frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT]),
)

@@ -511,6 +514,11 @@ def _build_base_metric_output_node(
),
descendent_filter_specs=metric_spec.filter_specs,
)
if predicate_pushdown_state.where_filter_pushdown_enabled:
predicate_pushdown_state = PredicatePushdownState.with_additional_where_filter_specs(
original_pushdown_state=predicate_pushdown_state,
additional_where_filter_specs=metric_input_measure_spec.filter_specs,
)

logger.info(
f"For\n{indent(mf_pformat(metric_spec))}"
@@ -568,6 +576,9 @@ def _build_derived_metric_output_node(

# If metric is offset, we'll apply where constraint after offset to avoid removing values
# unexpectedly. Time constraint will be applied by INNER JOINing to time spine.
# We may consider encapsulating this in pushdown state later, but for now pushdown
# means moving filters from post-join to pre-join evaluation for dimension access,
# and it relies on the builder to collect predicates from query and metric specs
# and make them available at the measure level.
if not metric_spec.has_time_offset:
filter_specs.extend(metric_spec.filter_specs)

@@ -751,7 +762,9 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> DataflowPlan:
required_linkable_specs, _ = self.__get_required_and_extraneous_linkable_specs(
queried_linkable_specs=query_spec.linkable_specs, filter_specs=query_level_filter_specs
)
predicate_pushdown_state = PredicatePushdownState(time_range_constraint=query_spec.time_range_constraint)
predicate_pushdown_state = PredicatePushdownState(
time_range_constraint=query_spec.time_range_constraint, where_filter_specs=query_level_filter_specs
)

Contributor: Where does this get narrowed down to only categorical dimensions? 🤔

dataflow_recipe = self._find_dataflow_recipe(
linkable_spec_set=required_linkable_specs, predicate_pushdown_state=predicate_pushdown_state
)
@@ -954,7 +967,14 @@ def _find_dataflow_recipe(
node_data_set_resolver=self._node_data_set_resolver,
)

if predicate_pushdown_state.has_pushdown_potential:
if predicate_pushdown_state.has_pushdown_potential and default_join_type is not SqlJoinType.FULL_OUTER:
# TODO: encapsulate join type and distinct values state and eventually move this to a DataflowPlanOptimizer
# This works today because all of our subsequent join configuration operations preserve the join type
# as-is, or else switch it to a CROSS JOIN or INNER JOIN type, both of which are safe for predicate
# pushdown. However, there is currently no way to enforce that invariant, so we will need to move
# to a model where we evaluate the join nodes themselves and decide on whether or not to push down
# the predicate. This will be much more straightforward once we finish encapsulating our existing
# time range constraint pushdown controls into this mechanism.
candidate_nodes_for_left_side_of_join = list(
node_processor.apply_matching_filter_predicates(
source_nodes=candidate_nodes_for_left_side_of_join,
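The FULL OUTER JOIN guard in the hunk above deserves a concrete example. Pushing a predicate below a full outer join is not semantics-preserving, because right-side-only rows re-enter the result with NULLs where the post-join filter would have removed them. A toy demonstration in plain Python (not MetricFlow code):

listings = {"l1": "US", "l2": "FR"}  # listing -> country (left input)
bookings = {"l2": 50, "l3": 75}      # listing -> revenue (right input)


def full_outer_join(left, right):
    """Emit every key from either side, with None standing in for SQL NULL."""
    return {k: (left.get(k), right.get(k)) for k in set(left) | set(right)}


# Correct plan: join first, then filter country = 'US'. Rows with NULL or
# non-US country (l2, l3) are removed by the post-join filter.
post = {k: v for k, v in full_outer_join(listings, bookings).items() if v[0] == "US"}
# post == {"l1": ("US", None)}

# Pushed-down plan: filter the left input first. The full outer join then
# re-emits l2 and l3 from the right side with NULL country -- rows the
# filter was supposed to remove. The results diverge, hence the guard.
pre = full_outer_join({k: c for k, c in listings.items() if c == "US"}, bookings)
# pre == {"l1": ("US", None), "l2": (None, 50), "l3": (None, 75)}
assert post != pre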
147 changes: 131 additions & 16 deletions metricflow/plan_conversion/node_processor.py
@@ -3,15 +3,16 @@
import dataclasses
import logging
from enum import Enum
from typing import FrozenSet, List, Optional, Sequence, Set
from typing import Dict, FrozenSet, List, Optional, Sequence, Set

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.references import EntityReference, TimeDimensionReference
from dbt_semantic_interfaces.references import EntityReference, SemanticModelReference, TimeDimensionReference
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.mf_logging.pretty_print import mf_pformat
from metricflow_semantics.model.semantics.linkable_element import LinkableElementType
from metricflow_semantics.model.semantics.semantic_model_join_evaluator import MAX_JOIN_HOPS
from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec, LinklessEntitySpec
from metricflow_semantics.specs.spec_classes import LinkableInstanceSpec, LinklessEntitySpec, WhereFilterSpec
from metricflow_semantics.specs.spec_set import group_specs_by_type
from metricflow_semantics.specs.spec_set_transforms import ToElementNameSet
from metricflow_semantics.sql.sql_join_type import SqlJoinType
@@ -25,6 +26,7 @@
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.validation.dataflow_join_validator import JoinDataflowOutputValidator

logger = logging.getLogger(__name__)
@@ -95,7 +97,10 @@ class PredicatePushdownState:
"""

time_range_constraint: Optional[TimeRangeConstraint]
pushdown_enabled_types: FrozenSet[PredicateInputType] = frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT])
where_filter_specs: Sequence[WhereFilterSpec]
pushdown_enabled_types: FrozenSet[PredicateInputType] = frozenset(
[PredicateInputType.TIME_RANGE_CONSTRAINT, PredicateInputType.CATEGORICAL_DIMENSION]
)

def __post_init__(self) -> None:
"""Validation to ensure pushdown states are configured correctly.
@@ -107,13 +112,12 @@ def __post_init__(self) -> None:
invalid_types: Set[PredicateInputType] = set()

for input_type in self.pushdown_enabled_types:
if (
if input_type is PredicateInputType.ENTITY or input_type is PredicateInputType.TIME_DIMENSION:
invalid_types.add(input_type)
elif (
input_type is PredicateInputType.CATEGORICAL_DIMENSION
or input_type is PredicateInputType.ENTITY
or input_type is PredicateInputType.TIME_DIMENSION
or input_type is PredicateInputType.TIME_RANGE_CONSTRAINT
):
invalid_types.add(input_type)
elif input_type is PredicateInputType.TIME_RANGE_CONSTRAINT:
continue
else:
assert_values_exhausted(input_type)
@@ -125,23 +129,24 @@ def __post_init__(self) -> None:
f"for {self.pushdown_enabled_types}, which includes the following invalid types: {invalid_types}."
)

# TODO: Include where filter specs when they are added to this class
time_range_constraint_is_valid = (
self.time_range_constraint is None
or PredicateInputType.TIME_RANGE_CONSTRAINT in self.pushdown_enabled_types
)
assert time_range_constraint_is_valid, (
where_filter_specs_are_valid = len(self.where_filter_specs) == 0 or self.where_filter_pushdown_enabled
assert time_range_constraint_is_valid and where_filter_specs_are_valid, (
"Invalid pushdown state configuration! Disabled pushdown state objects cannot have properties "
"set that may lead to improper access and use in other contexts, as that can lead to unintended "
"filtering operations in cases where these properties are accessed without appropriate checks against "
"pushdown configuration. The following properties should all have None values:\n"
f"time_range_constraint: {self.time_range_constraint}"
"pushdown configuration. The following properties should be None or empty:\n"
f"time_range_constraint: {self.time_range_constraint}\n"
f"where_filter_specs: {self.where_filter_specs}"
)

@property
def has_pushdown_potential(self) -> bool:
"""Returns whether or not pushdown is enabled for a type with predicate candidates in place."""
return self.has_time_range_constraint_to_push_down
return self.has_time_range_constraint_to_push_down or self.has_where_filters_to_push_down

@property
def has_time_range_constraint_to_push_down(self) -> bool:
@@ -156,6 +161,44 @@ def has_time_range_constraint_to_push_down(self) -> bool:
and self.time_range_constraint is not None
)

@property
def has_where_filters_to_push_down(self) -> bool:
"""Convenience accessor for checking if there are any where filters to push down."""
return self.where_filter_pushdown_enabled and len(self.where_filter_specs) > 0

@property
def where_filter_pushdown_enabled(self) -> bool:
"""Indicates whether or not pushdown is enabled for where filters."""
return (
PredicateInputType.CATEGORICAL_DIMENSION in self.pushdown_enabled_types
or PredicateInputType.ENTITY in self.pushdown_enabled_types
or PredicateInputType.TIME_DIMENSION in self.pushdown_enabled_types
)

@property
def pushdown_eligible_element_types(self) -> FrozenSet[LinkableElementType]:
"""Set of linkable element types eligible for predicate pushdown.

This converts from enabled PushdownInputTypes for checking if linkable elements in where filter specs are
eligible for pushdown.
"""
eligible_types: List[LinkableElementType] = []
for enabled_type in self.pushdown_enabled_types:
if enabled_type is PredicateInputType.TIME_RANGE_CONSTRAINT:
pass
elif enabled_type is PredicateInputType.CATEGORICAL_DIMENSION:
eligible_types.append(LinkableElementType.DIMENSION)
elif enabled_type is PredicateInputType.TIME_DIMENSION or enabled_type is PredicateInputType.ENTITY:
# TODO: Remove as support for time dimensions and entities becomes available
raise NotImplementedError(
"Predicate pushdown is not currently supported for where filter predicates with time dimension or "
f"entity references, but this pushdown state is enabled for {enabled_type}."
)
else:
assert_values_exhausted(enabled_type)

return frozenset(eligible_types)

@staticmethod
def with_time_range_constraint(
original_pushdown_state: PredicatePushdownState, time_range_constraint: TimeRangeConstraint
@@ -169,7 +212,9 @@ def with_time_range_constraint(
{PredicateInputType.TIME_RANGE_CONSTRAINT}
)
return PredicatePushdownState(
time_range_constraint=time_range_constraint, pushdown_enabled_types=pushdown_enabled_types
time_range_constraint=time_range_constraint,
pushdown_enabled_types=pushdown_enabled_types,
where_filter_specs=original_pushdown_state.where_filter_specs,
)

@staticmethod
@@ -180,7 +225,27 @@ def without_time_range_constraint(
pushdown_enabled_types = original_pushdown_state.pushdown_enabled_types.difference(
{PredicateInputType.TIME_RANGE_CONSTRAINT}
)
return PredicatePushdownState(time_range_constraint=None, pushdown_enabled_types=pushdown_enabled_types)
return PredicatePushdownState(
time_range_constraint=None,
pushdown_enabled_types=pushdown_enabled_types,
where_filter_specs=original_pushdown_state.where_filter_specs,
)

@staticmethod
def with_additional_where_filter_specs(
original_pushdown_state: PredicatePushdownState, additional_where_filter_specs: Sequence[WhereFilterSpec]
) -> PredicatePushdownState:
"""Factory method for adding additional WhereFilterSpecs for pushdown operations.

This requires that the PushdownState allow for where filters - time range only or disabled states will
raise an exception, and must be checked externally.
"""
updated_where_specs = tuple(original_pushdown_state.where_filter_specs) + tuple(additional_where_filter_specs)
return PredicatePushdownState(
time_range_constraint=original_pushdown_state.time_range_constraint,
where_filter_specs=updated_where_specs,
pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types,
)

@staticmethod
def with_pushdown_disabled() -> PredicatePushdownState:
@@ -194,6 +259,7 @@ def with_pushdown_disabled() -> PredicatePushdownState:
return PredicatePushdownState(
time_range_constraint=None,
pushdown_enabled_types=frozenset(),
where_filter_specs=tuple(),
)
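
Taken together, these factory methods let the plan builder thread pushdown state immutably: copy, extend, or neutralize, but never mutate. A hedged usage sketch mirroring the builder changes in this PR (the spec and constraint variables are placeholders, not real MetricFlow objects):

# Hedged sketch of how the builder threads pushdown state.
state = PredicatePushdownState(
    time_range_constraint=query_time_constraint,  # placeholder; may be None
    where_filter_specs=query_level_filter_specs,  # placeholder query-scoped filters
)
# With the defaults, only categorical dimensions are eligible for where-filter pushdown.
assert state.pushdown_eligible_element_types == frozenset([LinkableElementType.DIMENSION])

# Each base metric adds its measure-input filters before building measure nodes.
if state.where_filter_pushdown_enabled:
    state = PredicatePushdownState.with_additional_where_filter_specs(
        original_pushdown_state=state,
        additional_where_filter_specs=metric_input_measure_spec.filter_specs,
    )

# Contexts that cannot safely push anything down get an inert state instead.
disabled = PredicatePushdownState.with_pushdown_disabled()
assert not disabled.has_pushdown_potential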


@@ -240,6 +306,13 @@ def apply_matching_filter_predicates(
time_range_constraint=predicate_pushdown_state.time_range_constraint,
)

if predicate_pushdown_state.has_where_filters_to_push_down:
source_nodes = self._add_where_constraint(
source_nodes=source_nodes,
where_filter_specs=predicate_pushdown_state.where_filter_specs,
enabled_element_types=predicate_pushdown_state.pushdown_eligible_element_types,
)

return source_nodes

def _add_time_range_constraint(
@@ -272,6 +345,48 @@ def _add_time_range_constraint(
processed_nodes.append(source_node)
return processed_nodes

def _add_where_constraint(
self,
source_nodes: Sequence[DataflowPlanNode],
where_filter_specs: Sequence[WhereFilterSpec],
enabled_element_types: FrozenSet[LinkableElementType],
) -> Sequence[DataflowPlanNode]:
"""Processes where filter specs and evaluates their fitness for pushdown against the provided node set."""
eligible_filter_specs_by_model: Dict[SemanticModelReference, Sequence[WhereFilterSpec]] = {}
for spec in where_filter_specs:
semantic_models = set(element.semantic_model_origin for element in spec.linkable_elements)
invalid_element_types = [
element for element in spec.linkable_elements if element.element_type not in enabled_element_types
]

Contributor: Ok this is the logic I was looking for 👍

if len(semantic_models) == 1 and len(invalid_element_types) == 0:

Contributor: Ok, so we only push down filters if ALL the filtered elements are eligible element types. Is that because we don't know if this is an AND or an OR filter at this point?

Contributor (author): Correct, we cannot push down a filter with any invalid element types. The AND vs OR nature of things isn't relevant; it's because we don't have a way to handle those element types. Right now that's just because we haven't implemented handling, but in the future it could be due to a given query being too difficult to manage for a given element type.

For example, agg time dimension filters against a mixture of cumulative and derived offset metric inputs could get very tricky. In those cases we may not be able to push down a where filter with a time dimension.

My expectation is that this will be more refined than clobbering anything that has a time dimension of any kind in it, but for now this definitely works and we can use more finesse later.

model = semantic_models.pop()
eligible_filter_specs_by_model[model] = tuple(eligible_filter_specs_by_model.get(model, tuple())) + (
spec,
)

filtered_nodes: List[DataflowPlanNode] = []
for source_node in source_nodes:
node_semantic_models = tuple(source_node.as_plan().source_semantic_models)
if len(node_semantic_models) == 1 and node_semantic_models[0] in eligible_filter_specs_by_model:
eligible_filter_specs = eligible_filter_specs_by_model[node_semantic_models[0]]
source_node_specs = self._node_data_set_resolver.get_output_data_set(source_node).instance_set.spec_set
matching_filter_specs = [
filter_spec
for filter_spec in eligible_filter_specs
if all([spec in source_node_specs.linkable_specs for spec in filter_spec.linkable_specs])
]

Contributor: Is there ever a time where this won't be true, since we get the semantic model from the linkable element above? Or is this just an extra safety check in case something gets misconfigured?

Contributor (author): At this time it's a safeguard against something weird happening where a given source node isn't configured correctly. However, I expect this filter to be relevant for entities, since they may be defined in multiple semantic models and we need to be able to explicitly allow or disallow pushdown in those cases. If we ever add a pre-joined source node, for example, we might encounter a scenario where the entity and dimension come from different semantic models, and then we couldn't push down past this point (and maybe shouldn't push down to this point, either).

if len(matching_filter_specs) == 0:
filtered_nodes.append(source_node)
else:
where_constraint = WhereFilterSpec.merge_iterable(matching_filter_specs)
filtered_nodes.append(
WhereConstraintNode(parent_node=source_node, where_constraint=where_constraint)
)
else:
filtered_nodes.append(source_node)

return filtered_nodes
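
The eligibility and matching rules above reduce to two small checks, restated here as a self-contained sketch (simplified paraphrase with placeholder data, not a verbatim excerpt of _add_where_constraint):

# 1. Eligibility: a filter spec qualifies only if every referenced element is of
#    an enabled type AND all elements originate from a single semantic model.
def is_pushdown_eligible(spec, enabled_element_types) -> bool:
    models = {element.semantic_model_origin for element in spec.linkable_elements}
    return len(models) == 1 and all(
        element.element_type in enabled_element_types for element in spec.linkable_elements
    )

# 2. Matching: per source node, keep only the eligible specs the node's own
#    output can satisfy; anything needing a join stays on the post-join path.
eligible_specs_for_model = [
    {"name": "is_instant_filter", "needs": {"booking__is_instant"}},
    {"name": "region_filter", "needs": {"booking__host__home_region"}},
]
node_output_specs = {"booking__is_instant"}  # what this source node exposes

matching = [s for s in eligible_specs_for_model if s["needs"] <= node_output_specs]
# Only is_instant_filter is fully satisfiable at the source; region_filter
# requires a join and is left to the normal post-join filter path.
assert [s["name"] for s in matching] == ["is_instant_filter"]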

def _node_contains_entity(
self,
node: DataflowPlanNode,
11 changes: 7 additions & 4 deletions tests_metricflow/dataflow/builder/test_predicate_pushdown.py
@@ -9,9 +9,7 @@
@pytest.fixture
def fully_enabled_pushdown_state() -> PredicatePushdownState:
"""Tests a valid configuration with all predicate properties set and pushdown fully enabled."""
params = PredicatePushdownState(
time_range_constraint=TimeRangeConstraint.all_time(),
)
params = PredicatePushdownState(time_range_constraint=TimeRangeConstraint.all_time(), where_filter_specs=tuple())
return params


@@ -20,6 +18,7 @@ def test_time_range_pushdown_enabled_states(fully_enabled_pushdown_state: PredicatePushdownState) -> None:
time_range_only_state = PredicatePushdownState(
time_range_constraint=TimeRangeConstraint.all_time(),
pushdown_enabled_types=frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT]),
where_filter_specs=tuple(),
)

enabled_states = {
@@ -39,4 +38,8 @@ def test_invalid_disabled_pushdown_state() -> None:
def test_invalid_disabled_pushdown_state() -> None:
"""Tests checks for invalid param configuration on disabled pushdown parameters."""
with pytest.raises(AssertionError, match="Disabled pushdown state objects cannot have properties set"):
PredicatePushdownState(time_range_constraint=TimeRangeConstraint.all_time(), pushdown_enabled_types=frozenset())
PredicatePushdownState(
time_range_constraint=TimeRangeConstraint.all_time(),
pushdown_enabled_types=frozenset(),
where_filter_specs=tuple(),
)
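
A natural companion test, not part of this PR, would pin down the new where-filter flag directly (hypothetical sketch reusing the fixtures above):

def test_where_filter_pushdown_enabled_states(fully_enabled_pushdown_state: PredicatePushdownState) -> None:
    """Hypothetical companion test: where-filter pushdown tracks the enabled input types."""
    assert fully_enabled_pushdown_state.where_filter_pushdown_enabled
    time_range_only_state = PredicatePushdownState(
        time_range_constraint=TimeRangeConstraint.all_time(),
        pushdown_enabled_types=frozenset([PredicateInputType.TIME_RANGE_CONSTRAINT]),
        where_filter_specs=tuple(),
    )
    assert not time_range_only_state.where_filter_pushdown_enabled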