diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
index 79b4790dd..ec202d0c0 100644
--- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
+++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
@@ -233,10 +233,17 @@ def _push_down_where_filters(
 
         for filter_spec in current_pushdown_state.where_filter_specs:
             filter_spec_semantic_models = self._models_for_spec(filter_spec)
+            invalid_element_types = [
+                element
+                for element in filter_spec.linkable_elements
+                if element.element_type not in current_pushdown_state.pushdown_eligible_element_types
+            ]
+            if len(filter_spec_semantic_models) != 1 or len(invalid_element_types) > 0:
+                continue
+
             all_linkable_specs_match = all(spec in source_node_linkable_specs for spec in filter_spec.linkable_specs)
-            semantic_models_match = (
-                len(filter_spec_semantic_models) == 1 and filter_spec_semantic_models[0] == source_semantic_model
-            )
+            # TODO: Handle the case where entities can be defined in multiple models, only one of which need match
+            semantic_models_match = filter_spec_semantic_models[0] == source_semantic_model
             if all_linkable_specs_match and semantic_models_match:
                 filters_to_apply.append(filter_spec)
             else:
@@ -277,32 +284,24 @@ def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> Optim
     def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBranchResult:
         """Adds where filters from the input node to the current pushdown state.
 
-        The WhereConstraintNode carries the filter information in the form of WhereFilterSpecs. For any
-        filter specs that may be eligible for predicate pushdown this node will add them to the pushdown state.
+        The WhereConstraintNode carries the filter information in the form of WhereFilterSpecs, which may or may
+        not be eligible for pushdown. This processor simply propagates them forward so long as where filter
+        predicate pushdown is still enabled for this branch.
+
         The fact that they have been added at this point does not mean they will be pushed down, as intervening
-        join nodes might remove them from consideration, so we retain them here as well in order to ensure all
-        filters are applied as specified.
+        join nodes might remove them from consideration, so we retain them here to ensure all filters are applied
+        as specified within this method.
+
+        TODO: Update to only apply filters that have not been pushed down
         """
         self._log_visit_node_type(node)
         current_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state
         if not current_pushdown_state.where_filter_pushdown_enabled:
             return self._default_handler(node)
 
-        where_specs = node.input_where_specs
-        pushdown_eligible_specs: List[WhereFilterSpec] = []
-        for spec in where_specs:
-            semantic_models = self._models_for_spec(spec)
-            invalid_element_types = [
-                element
-                for element in spec.linkable_elements
-                if element.element_type not in current_pushdown_state.pushdown_eligible_element_types
-            ]
-            if len(semantic_models) != 1 or len(invalid_element_types) > 0:
-                continue
-            pushdown_eligible_specs.append(spec)
-
-        updated_pushdown_state = PredicatePushdownState.with_additional_where_filter_specs(
-            original_pushdown_state=current_pushdown_state, additional_where_filter_specs=tuple(pushdown_eligible_specs)
+        updated_pushdown_state = PredicatePushdownState.with_where_filter_specs(
+            original_pushdown_state=current_pushdown_state,
+            where_filter_specs=tuple(current_pushdown_state.where_filter_specs) + tuple(node.input_where_specs),
         )
 
         return self._default_handler(node=node, pushdown_state=updated_pushdown_state)
@@ -323,10 +322,8 @@ def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNo
         """
         self._log_visit_node_type(node)
         # TODO: move this "remove where filters" logic into PredicatePushdownState
-        updated_pushdown_state = PredicatePushdownState(
-            time_range_constraint=self._predicate_pushdown_tracker.last_pushdown_state.time_range_constraint,
-            where_filter_specs=tuple(),
-            pushdown_enabled_types=self._predicate_pushdown_tracker.last_pushdown_state.pushdown_enabled_types,
+        updated_pushdown_state = PredicatePushdownState.without_where_filter_specs(
+            original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state,
         )
 
         return self._default_handler(node=node, pushdown_state=updated_pushdown_state)
@@ -342,10 +339,8 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O
         """
         self._log_visit_node_type(node)
 
-        base_node_pushdown_state = PredicatePushdownState(
-            time_range_constraint=self._predicate_pushdown_tracker.last_pushdown_state.time_range_constraint,
-            where_filter_specs=tuple(),
-            pushdown_enabled_types=self._predicate_pushdown_tracker.last_pushdown_state.pushdown_enabled_types,
+        base_node_pushdown_state = PredicatePushdownState.without_where_filter_specs(
+            original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state,
         )
         # The conversion metric branch silently removes all filters, so this is a redundant operation.
         # TODO: Enable pushdown for the conversion metric branch when filters are supported
@@ -384,10 +379,8 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc
         self._log_visit_node_type(node)
         left_parent = node.left_node
         if any(join_description.join_type is SqlJoinType.FULL_OUTER for join_description in node.join_targets):
-            left_branch_pushdown_state = PredicatePushdownState(
-                time_range_constraint=self._predicate_pushdown_tracker.last_pushdown_state.time_range_constraint,
-                where_filter_specs=tuple(),
-                pushdown_enabled_types=self._predicate_pushdown_tracker.last_pushdown_state.pushdown_enabled_types,
+            left_branch_pushdown_state = PredicatePushdownState.without_where_filter_specs(
+                original_pushdown_state=self._predicate_pushdown_tracker.last_pushdown_state,
             )
         else:
             left_branch_pushdown_state = self._predicate_pushdown_tracker.last_pushdown_state
@@ -399,10 +392,8 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranc
         base_right_branch_pushdown_state = PredicatePushdownState.without_time_range_constraint(
             self._predicate_pushdown_tracker.last_pushdown_state
         )
-        outer_join_right_branch_pushdown_state = PredicatePushdownState(
-            time_range_constraint=None,
-            where_filter_specs=tuple(),
-            pushdown_enabled_types=base_right_branch_pushdown_state.pushdown_enabled_types,
+        outer_join_right_branch_pushdown_state = PredicatePushdownState.without_where_filter_specs(
+            original_pushdown_state=base_right_branch_pushdown_state
         )
         for join_description in node.join_targets:
             if (
diff --git a/metricflow/plan_conversion/node_processor.py b/metricflow/plan_conversion/node_processor.py
index 28054d49b..a2422fdd3 100644
--- a/metricflow/plan_conversion/node_processor.py
+++ b/metricflow/plan_conversion/node_processor.py
@@ -221,7 +221,14 @@ def with_time_range_constraint(
     def without_time_range_constraint(
         original_pushdown_state: PredicatePushdownState,
     ) -> PredicatePushdownState:
-        """Factory method for updating pushdown state to bypass time range constraints."""
+        """Factory method for updating pushdown state to bypass time range constraints.
+
+        This eliminates time range constraint pushdown as an option, since the only reason to remove
+        time range constraint metadata is to turn it off, so we avoid potential issues where
+        a second ConstrainTimeRange node might update the pushdown state.
+
+        TODO: replace or rename this method.
+        """
         pushdown_enabled_types = original_pushdown_state.pushdown_enabled_types.difference(
             {PredicateInputType.TIME_RANGE_CONSTRAINT}
         )
@@ -232,18 +239,30 @@ def without_time_range_constraint(
         )
 
     @staticmethod
-    def with_additional_where_filter_specs(
-        original_pushdown_state: PredicatePushdownState, additional_where_filter_specs: Sequence[WhereFilterSpec]
+    def without_where_filter_specs(
+        original_pushdown_state: PredicatePushdownState,
+    ) -> PredicatePushdownState:
+        """Factory method for updating pushdown state to remove existing where filter specs.
+
+        This simply blanks out the where filter specs without altering which types of pushdown are available.
+        """
+        return PredicatePushdownState.with_where_filter_specs(
+            original_pushdown_state=original_pushdown_state,
+            where_filter_specs=tuple(),
+        )
+
+    @staticmethod
+    def with_where_filter_specs(
+        original_pushdown_state: PredicatePushdownState, where_filter_specs: Sequence[WhereFilterSpec]
     ) -> PredicatePushdownState:
-        """Factory method for adding additional WhereFilterSpecs for pushdown operations.
+        """Factory method for replacing WhereFilterSpecs in pushdown operations.
 
-        This requires that the PushdownState allow for where filters - time range only or disabled states will
-        raise an exception, and must be checked externally.
+        This requires that the PushdownState allow for where filters - time range only or disabled states will raise
+        an exception, and must be checked externally.
         """
-        updated_where_specs = tuple(original_pushdown_state.where_filter_specs) + tuple(additional_where_filter_specs)
         return PredicatePushdownState(
             time_range_constraint=original_pushdown_state.time_range_constraint,
-            where_filter_specs=updated_where_specs,
+            where_filter_specs=where_filter_specs,
             pushdown_enabled_types=original_pushdown_state.pushdown_enabled_types,
         )
 