From 20e31a04ae83e5ff81ac12217c772766810cf43c Mon Sep 17 00:00:00 2001 From: "tom.bonfert" Date: Mon, 11 May 2026 20:49:10 +0200 Subject: [PATCH 1/2] Enhance KeyValueStoreSolver to support MetricExpression filters and update documentation * Added support for MetricExpression filters in KeyValueStoreSolver's filter_container_metrics method. * Updated the report validation to reflect that KeyValueStoreSolver now supports both tag and metric filters. * Modified tests to ensure proper functionality of new metric filtering capabilities and updated container_metrics column mappings. --- .../query/solvers/key_value_store_solver.py | 51 +++- src/mda_reporting/core/report.py | 9 +- .../integration/kvs_solver_test.py | 255 ++++++++++++++++++ .../solvers/custom_column_mapping_test.py | 8 +- .../solvers/key_value_store_alias_test.py | 10 + .../solvers/key_value_store_solver_test.py | 109 +++++++- .../integration/simple_report_test.py | 3 + 7 files changed, 423 insertions(+), 22 deletions(-) create mode 100644 tests/mda_query_engine/integration/kvs_solver_test.py diff --git a/src/mda_query_engine/analyze/query/solvers/key_value_store_solver.py b/src/mda_query_engine/analyze/query/solvers/key_value_store_solver.py index 96bb736..9b500b9 100644 --- a/src/mda_query_engine/analyze/query/solvers/key_value_store_solver.py +++ b/src/mda_query_engine/analyze/query/solvers/key_value_store_solver.py @@ -5,6 +5,7 @@ import pyspark.sql.functions as F from pyspark.sql import DataFrame, Window +from mda_query_engine.analyze.metadata.metric_expression import MetricExpression from mda_query_engine.analyze.metadata.tag_expression import TagExpression from .basic_narrow_solver import BasicNarrowSolver @@ -107,8 +108,14 @@ def filter_container_metrics( self, spark, query, container_df, pre_filtered_containers_df=None ) -> DataFrame: """ - Filter containers by joining container_metrics with the tag-filtered - container DataFrame. + Filter container_metrics and join with tag-filtered container IDs. 
+ + Reads the ``container_metrics`` table, applies the per-table + ``column_name_mapping`` to rename physical columns to internal names, + applies the top-level ``project_id`` filter, any per-table + ``container_metrics.filters``, and any ``MetricExpression`` filters + extracted from the query. Finally, inner-joins the result with the + tag-filtered container DataFrame. Parameters ---------- @@ -117,26 +124,48 @@ def filter_container_metrics( query : QueryBuilder Query object containing filters and db info. container_df : pyspark.sql.DataFrame - DataFrame containing tag-filtered container IDs. + DataFrame containing tag-filtered container IDs (output of + :meth:`filter_container_tags`). pre_filtered_containers_df : pyspark.sql.DataFrame, optional - DataFrame containing pre-filtered container information. + Pre-filtered container_metrics DataFrame. When provided, it + replaces the read from ``query.db.container_metrics``. Returns ------- pyspark.sql.DataFrame - DataFrame containing filtered container metrics. + Filtered container metrics with all original columns preserved. + Deduplicated by ``container_id``. 
""" container_id_col = self.config.container_id_col + metric_filters = [ + filt for filt in query.filters if isinstance(filt, MetricExpression) + ] + if pre_filtered_containers_df is not None: - container_metrics = pre_filtered_containers_df + metrics = pre_filtered_containers_df else: - container_metrics = query.db.container_metrics(self.spark) - container_metrics = self._apply_column_mapping( - container_metrics, self.config.container_metrics.column_name_mapping + metrics = query.db.container_metrics(self.spark) + + metrics = self._apply_column_mapping( + metrics, self.config.container_metrics.column_name_mapping ) - return container_metrics.join( - container_df, how="inner", on=container_id_col + + if self.config.project_id is not None: + metrics = metrics.where( + F.col(self.config.project_id_col) == self.config.project_id + ) + + for col_name, value in self.config.container_metrics.filters.items(): + metrics = metrics.where(F.col(col_name) == value) + + if len(metric_filters) > 0: + metrics = metrics.where(self._build_expr(metric_filters)) + + return metrics.join( + F.broadcast(container_df.select(container_id_col).distinct()), + on=container_id_col, + how="inner", ).dropDuplicates([container_id_col]) def filter_aliased_channel_metrics( diff --git a/src/mda_reporting/core/report.py b/src/mda_reporting/core/report.py index 63df7d7..576a1fd 100644 --- a/src/mda_reporting/core/report.py +++ b/src/mda_reporting/core/report.py @@ -225,7 +225,7 @@ def create_query_builder(db: MeasurementDB, config: MdaConfig) -> QueryBuilder: Validates solver/filter compatibility before applying filters: - BasicNarrowSolver supports metric filters only (rejects tag filters). - - KeyValueStoreSolver supports tag filters only (rejects metric filters). + - KeyValueStoreSolver supports both tag and metric filters. - DeltaSolver supports both tag and metric filters. 
Parameters @@ -249,7 +249,6 @@ def create_query_builder(db: MeasurementDB, config: MdaConfig) -> QueryBuilder: if config.container_filters is not None: has_tag_filters = len(config.container_filters.tag_filters) > 0 - has_metric_filters = len(config.container_filters.metric_filters) > 0 if has_tag_filters and config.query_engine.solver == Solvers.BASIC_NARROW_SOLVER: raise ValueError( @@ -257,12 +256,6 @@ def create_query_builder(db: MeasurementDB, config: MdaConfig) -> QueryBuilder: "Use DeltaSolver or KeyValueStoreSolver for tag-based filtering." ) - if has_metric_filters and config.query_engine.solver == Solvers.KEY_VALUE_STORE_SOLVER: - raise ValueError( - "Metric filters are not supported with KeyValueStoreSolver. " - "Use DeltaSolver or BasicNarrowSolver for metric-based filtering." - ) - tag_filter_expr = ReportEntityUtil.generate_tag_filters( query, config.container_filters.tag_filters ) diff --git a/tests/mda_query_engine/integration/kvs_solver_test.py b/tests/mda_query_engine/integration/kvs_solver_test.py new file mode 100644 index 0000000..442fd4c --- /dev/null +++ b/tests/mda_query_engine/integration/kvs_solver_test.py @@ -0,0 +1,255 @@ +# pylint: disable=missing-function-docstring +"""End-to-end integration tests for the KeyValueStoreSolver. + +Exercises the full solver pipeline (filter_container_tags → +filter_container_metrics → filter_channel_tags → filter_channel_metrics → +solve) by calling ``QueryBuilder.solve`` with a ``KeyValueStoreSolver`` +instance against the existing ``key_value_store_db`` / +``key_value_store_alias_db`` fixtures. 
+""" + +import pyspark.sql.functions as F +import pytest +from pyspark.sql import SparkSession + +from mda_query_engine.analyze.metadata.metric_expression import MetricSelector +from mda_query_engine.analyze.metadata.tag_expression import TagSelector +from mda_query_engine.analyze.query.aggregations.stats_aggregator import StatsAggregator +from mda_query_engine.analyze.query.solvers.key_value_store_solver import ( + KeyValueStoreSolver, +) +from mda_query_engine.analyze.query.solvers.solver_config import ( + SolverConfig, + TableConfig, +) +from mda_query_engine.measurement_db import MeasurementDB + + +def _kvs_cfg( + project_id: str = "SAMPLE_PROJECT", + container_tags: TableConfig | None = None, + container_metrics: TableConfig | None = None, + channel_mapping: TableConfig | None = None, +) -> SolverConfig: + """Build a SolverConfig wired up for the KVS test data. + + The KVS fixture reuses the basic_narrow_csv ``container_metrics`` table, + which uses ``project`` instead of ``project_id``. The narrow EAV + ``container_tags`` table uses ``element_id`` instead of ``key``. 
+ """ + return SolverConfig( + project_id=project_id, + container_tags=container_tags + or TableConfig(column_name_mapping={"element_id": "key"}), + container_metrics=container_metrics + or TableConfig(column_name_mapping={"project": "project_id"}), + channel_mapping=channel_mapping or TableConfig(), + ) + + +class TestKeyValueStoreSolverIntegration: + """End-to-end pipeline tests against the key_value_store_db fixture.""" + + def test_solve_no_filters_returns_all_containers( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """Without any filter the solver should emit one result row per container.""" + solver = KeyValueStoreSolver(spark, config=_kvs_cfg()) + query = key_value_store_db.query + eng_rpm = query.channel(channel_name="Engine RPM") + + result = query.select(eng_rpm.mean().alias("rpm_mean")).solve( + spark=spark, solver=solver + ) + + container_ids = {row.container_id for row in result.collect()} + assert container_ids == {1, 2, 3} + + def test_solve_with_tag_expression_filter( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """A matching TagExpression keeps all containers; a non-matching one drops all.""" + solver = KeyValueStoreSolver(spark, config=_kvs_cfg()) + query = key_value_store_db.query + eng_rpm = query.channel(channel_name="Engine RPM") + query.where(TagSelector("brand") == "Seat") + + result = query.select(eng_rpm.mean().alias("rpm_mean")).solve( + spark=spark, solver=solver + ) + assert {row.container_id for row in result.collect()} == {1, 2, 3} + + query2 = key_value_store_db.query + query2.where(TagSelector("brand") == "VW") + result2 = query2.select( + query2.channel(channel_name="Engine RPM").mean().alias("rpm_mean") + ).solve(spark=spark, solver=solver) + assert result2.count() == 0 + + def test_solve_with_metric_expression_filter( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """MetricExpression on container_metrics should narrow the solve result.""" + solver = 
KeyValueStoreSolver(spark, config=_kvs_cfg()) + query = key_value_store_db.query + eng_rpm = query.channel(channel_name="Engine RPM") + query.where(MetricSelector("brand") == "Seat") + + result = query.select(eng_rpm.mean().alias("rpm_mean")).solve( + spark=spark, solver=solver + ) + assert {row.container_id for row in result.collect()} == {1, 2, 3} + + query2 = key_value_store_db.query + query2.where(MetricSelector("brand") == "VW") + result2 = query2.select( + query2.channel(channel_name="Engine RPM").mean().alias("rpm_mean") + ).solve(spark=spark, solver=solver) + assert result2.count() == 0 + + def test_solve_with_combined_tag_and_metric_filters( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """TagExpression (stage 1) + MetricExpression (stage 2) should both be applied.""" + solver = KeyValueStoreSolver(spark, config=_kvs_cfg()) + query = key_value_store_db.query + eng_rpm = query.channel(channel_name="Engine RPM") + query.where(TagSelector("brand") == "Seat") + query.where(MetricSelector("model") == "Leon") + + result = query.select(eng_rpm.mean().alias("rpm_mean")).solve( + spark=spark, solver=solver + ) + assert {row.container_id for row in result.collect()} == {1, 2, 3} + + # Tag matches, metric does not -> zero rows from stage 2 + query2 = key_value_store_db.query + query2.where(TagSelector("brand") == "Seat") + query2.where(MetricSelector("model") == "Ibiza") + result2 = query2.select( + query2.channel(channel_name="Engine RPM").mean().alias("rpm_mean") + ).solve(spark=spark, solver=solver) + assert result2.count() == 0 + + def test_solve_with_event_expression_and_stats( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """Event-gated stats aggregation should produce well-formed results per container.""" + solver = KeyValueStoreSolver(spark, config=_kvs_cfg()) + query = key_value_store_db.query + eng_rpm = query.channel(channel_name="Engine RPM") + veh_speed = query.channel(channel_name="Vehicle Speed Sensor") + 
high_speed_event = veh_speed > 50 + + stats_agg = StatsAggregator( + input_expressions=[eng_rpm], + statistics=["min", "max", "mean"], + event_expression=high_speed_event, + ) + + result = query.select(stats_agg.alias("rpm_when_fast")).solve( + spark=spark, solver=solver + ) + + assert result.count() == 3 + for row in result.collect(): + event_timestamps, numeric_values, _ = row["rpm_when_fast"] + assert len(numeric_values) == 1 + for event_stats in numeric_values[0]: + assert {"min", "max", "mean"}.issubset(event_stats.keys()) + if event_stats["min"] is not None: + assert event_stats["min"] <= event_stats["mean"] <= event_stats["max"] + for ts in event_timestamps: + assert ts[0] <= ts[1] + + def test_solve_non_existent_project_returns_empty( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """A project_id with no containers should yield zero solve rows.""" + solver = KeyValueStoreSolver(spark, config=_kvs_cfg("NON_EXISTENT_PROJECT")) + query = key_value_store_db.query + eng_rpm = query.channel(channel_name="Engine RPM") + + result = query.select(eng_rpm.mean().alias("rpm_mean")).solve( + spark=spark, solver=solver + ) + assert result.count() == 0 + + def test_solve_with_matching_parent_id_filter( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """``container_tags.filters`` should narrow containers via parent_id.""" + cfg = _kvs_cfg( + container_tags=TableConfig( + column_name_mapping={"element_id": "key"}, + filters={"parent_id": "container_concept"}, + ), + ) + solver = KeyValueStoreSolver(spark, config=cfg) + query = key_value_store_db.query + eng_rpm = query.channel(channel_name="Engine RPM") + + result = query.select(eng_rpm.mean().alias("rpm_mean")).solve( + spark=spark, solver=solver + ) + assert {row.container_id for row in result.collect()} == {1, 2, 3} + + def test_solve_with_non_matching_parent_id_filter( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """A parent_id that no row has should 
drop everything before stage 2.""" + cfg = _kvs_cfg( + container_tags=TableConfig( + column_name_mapping={"element_id": "key"}, + filters={"parent_id": "no_such_concept"}, + ), + ) + solver = KeyValueStoreSolver(spark, config=cfg) + query = key_value_store_db.query + eng_rpm = query.channel(channel_name="Engine RPM") + + result = query.select(eng_rpm.mean().alias("rpm_mean")).solve( + spark=spark, solver=solver + ) + assert result.count() == 0 + + def test_solve_with_pre_filtered_containers( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """``pre_filtered_containers_df`` restricts the solve to its container set.""" + solver = KeyValueStoreSolver(spark, config=_kvs_cfg()) + query = key_value_store_db.query + pre = query.db.container_metrics(spark).where(F.col("container_id") == 1) + + result = query.select( + query.channel(channel_name="Engine RPM").mean().alias("rpm_mean") + ).solve(spark=spark, solver=solver, pre_filtered_containers_df=pre) + + rows = result.collect() + assert {row.container_id for row in rows} == {1} + + +class TestKeyValueStoreSolverAliasIntegration: + """End-to-end pipeline tests against the alias-enabled KVS fixture.""" + + def test_solve_with_aliased_channel( + self, spark: SparkSession, key_value_store_alias_db: MeasurementDB + ): + """Aliased channel selection should resolve via channel_mapping and produce results.""" + solver = KeyValueStoreSolver( + spark, + config=_kvs_cfg( + channel_mapping=TableConfig(filters={"toolbox_id": "container_concept"}), + ), + ) + query = key_value_store_alias_db.query + engine_speed = query.channel_with_alias(channel_alias="engine_speed") + + result = query.select(engine_speed.mean().alias("engine_speed_mean")).solve( + spark=spark, solver=solver + ) + + rows = {row.container_id: row for row in result.collect()} + assert set(rows.keys()) == {1, 2, 3} + for row in rows.values(): + assert row["engine_speed_mean"] is not None diff --git 
a/tests/mda_query_engine/unit/analyze/query/solvers/custom_column_mapping_test.py b/tests/mda_query_engine/unit/analyze/query/solvers/custom_column_mapping_test.py index 33eb6ba..9624265 100644 --- a/tests/mda_query_engine/unit/analyze/query/solvers/custom_column_mapping_test.py +++ b/tests/mda_query_engine/unit/analyze/query/solvers/custom_column_mapping_test.py @@ -388,7 +388,9 @@ def _cfg() -> SolverConfig: container_tags=TableConfig( column_name_mapping={"entity_id": "container_id", "element_id": "key"}, ), - container_metrics=TableConfig(column_name_mapping={"meas_id": "container_id"}), + container_metrics=TableConfig( + column_name_mapping={"meas_id": "container_id", "project": "project_id"}, + ), channel_metrics=TableConfig(column_name_mapping={"meas_id": "container_id"}), channels=TableConfig(column_name_mapping={"meas_id": "container_id"}), ) @@ -479,7 +481,9 @@ def _make_cfg(project_id: str = "SAMPLE_PROJECT") -> SolverConfig: "element_id": "key", }, ), - container_metrics=TableConfig(column_name_mapping={"run_id": "container_id"}), + container_metrics=TableConfig( + column_name_mapping={"run_id": "container_id", "project": "project_id"}, + ), channel_metrics=TableConfig(column_name_mapping={"run_id": "container_id"}), channels=TableConfig( column_name_mapping={"run_id": "container_id", "attr_val": "value"}, diff --git a/tests/mda_query_engine/unit/analyze/query/solvers/key_value_store_alias_test.py b/tests/mda_query_engine/unit/analyze/query/solvers/key_value_store_alias_test.py index dfc654e..21ab7b4 100644 --- a/tests/mda_query_engine/unit/analyze/query/solvers/key_value_store_alias_test.py +++ b/tests/mda_query_engine/unit/analyze/query/solvers/key_value_store_alias_test.py @@ -30,6 +30,7 @@ def test_no_aliased_selections_returns_empty( spark, config=SolverConfig( project_id="SAMPLE_PROJECT", + container_metrics=TableConfig(column_name_mapping={"project": "project_id"}), channel_mapping=TableConfig(filters={"toolbox_id": "container_concept"}), ), ) @@ 
-50,6 +51,7 @@ def test_alias_resolves_to_correct_channels( spark, config=SolverConfig( project_id="SAMPLE_PROJECT", + container_metrics=TableConfig(column_name_mapping={"project": "project_id"}), channel_mapping=TableConfig(filters={"toolbox_id": "container_concept"}), ), ) @@ -77,6 +79,7 @@ def test_alias_scoped_by_project_id( spark, config=SolverConfig( project_id="NON_EXISTENT_PROJECT", + container_metrics=TableConfig(column_name_mapping={"project": "project_id"}), channel_mapping=TableConfig(filters={"toolbox_id": "container_concept"}), ), ) @@ -96,6 +99,7 @@ def test_alias_scoped_by_toolbox_id( spark, config=SolverConfig( project_id="SAMPLE_PROJECT", + container_metrics=TableConfig(column_name_mapping={"project": "project_id"}), channel_mapping=TableConfig(filters={"toolbox_id": "non_existent_toolbox"}), ), ) @@ -115,6 +119,7 @@ def test_selector_id_consistent_for_same_expression( spark, config=SolverConfig( project_id="SAMPLE_PROJECT", + container_metrics=TableConfig(column_name_mapping={"project": "project_id"}), channel_mapping=TableConfig(filters={"toolbox_id": "container_concept"}), ), ) @@ -134,6 +139,7 @@ def test_multiple_aliases(self, spark: SparkSession, key_value_store_alias_db: M spark, config=SolverConfig( project_id="SAMPLE_PROJECT", + container_metrics=TableConfig(column_name_mapping={"project": "project_id"}), channel_mapping=TableConfig(filters={"toolbox_id": "container_concept"}), ), ) @@ -158,6 +164,7 @@ def test_solve_with_alias_only( spark, config=SolverConfig( project_id="SAMPLE_PROJECT", + container_metrics=TableConfig(column_name_mapping={"project": "project_id"}), channel_mapping=TableConfig(filters={"toolbox_id": "container_concept"}), ), ) @@ -177,6 +184,7 @@ def test_solve_with_mixed_direct_and_alias( spark, config=SolverConfig( project_id="SAMPLE_PROJECT", + container_metrics=TableConfig(column_name_mapping={"project": "project_id"}), channel_mapping=TableConfig(filters={"toolbox_id": "container_concept"}), ), ) @@ -200,6 +208,7 
@@ def test_solve_deduplication( spark, config=SolverConfig( project_id="SAMPLE_PROJECT", + container_metrics=TableConfig(column_name_mapping={"project": "project_id"}), channel_mapping=TableConfig(filters={"toolbox_id": "container_concept"}), ), ) @@ -231,6 +240,7 @@ def test_alias_returns_same_channel_data_as_direct_engine_rpm( spark, config=SolverConfig( project_id="SAMPLE_PROJECT", + container_metrics=TableConfig(column_name_mapping={"project": "project_id"}), channel_mapping=TableConfig(filters={"toolbox_id": "container_concept"}), ), ) diff --git a/tests/mda_query_engine/unit/analyze/query/solvers/key_value_store_solver_test.py b/tests/mda_query_engine/unit/analyze/query/solvers/key_value_store_solver_test.py index 581837d..8ad44c8 100644 --- a/tests/mda_query_engine/unit/analyze/query/solvers/key_value_store_solver_test.py +++ b/tests/mda_query_engine/unit/analyze/query/solvers/key_value_store_solver_test.py @@ -14,6 +14,7 @@ - MetricExpression internal API tests """ +import pyspark.sql.functions as F import pytest from pyspark.sql import SparkSession @@ -36,7 +37,16 @@ def _default_cfg(project_id: str = "SAMPLE_PROJECT", **table_overrides) -> Solve "container_tags", TableConfig(column_name_mapping={"element_id": "key"}), ) - return SolverConfig(project_id=project_id, container_tags=container_tags, **table_overrides) + container_metrics = table_overrides.pop( + "container_metrics", + TableConfig(column_name_mapping={"project": "project_id"}), + ) + return SolverConfig( + project_id=project_id, + container_tags=container_tags, + container_metrics=container_metrics, + **table_overrides, + ) class TestKeyValueStoreSolverFilterContainerTags: @@ -198,6 +208,103 @@ def test_no_filter_returns_all_matching_metrics( container_ids = {row.container_id for row in result.collect()} assert len(container_ids) > 0 + def test_metric_expression_narrows_results( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """A MetricExpression filter on a 
container_metrics column should restrict results.""" + solver = KeyValueStoreSolver(spark, config=_default_cfg()) + query = key_value_store_db.query + query.where(MetricSelector("brand") == "Seat") + tags_df = solver.filter_container_tags(spark, query) + result = solver.filter_container_metrics(spark, query, tags_df) + container_ids = {row.container_id for row in result.collect()} + assert container_ids == {1, 2, 3} + + def test_non_matching_metric_expression_returns_empty( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """A MetricExpression that matches no container_metrics rows yields zero.""" + solver = KeyValueStoreSolver(spark, config=_default_cfg()) + query = key_value_store_db.query + query.where(MetricSelector("brand") == "VW") + tags_df = solver.filter_container_tags(spark, query) + result = solver.filter_container_metrics(spark, query, tags_df) + assert result.count() == 0 + + def test_config_container_metrics_filter_applied( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """``config.container_metrics.filters`` should be applied to container_metrics.""" + solver = KeyValueStoreSolver( + spark, + config=_default_cfg( + container_metrics=TableConfig( + column_name_mapping={"project": "project_id"}, + filters={"brand": "Seat"}, + ), + ), + ) + query = key_value_store_db.query + tags_df = solver.filter_container_tags(spark, query) + result = solver.filter_container_metrics(spark, query, tags_df) + container_ids = {row.container_id for row in result.collect()} + assert container_ids == {1, 2, 3} + + def test_config_container_metrics_filter_excludes_all( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """A non-matching ``container_metrics.filters`` value yields zero rows.""" + solver = KeyValueStoreSolver( + spark, + config=_default_cfg( + container_metrics=TableConfig( + column_name_mapping={"project": "project_id"}, + filters={"brand": "VW"}, + ), + ), + ) + query = key_value_store_db.query + 
tags_df = solver.filter_container_tags(spark, query) + result = solver.filter_container_metrics(spark, query, tags_df) + assert result.count() == 0 + + def test_non_existent_project_returns_empty( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """A non-existent project_id should yield zero container_metrics rows.""" + solver = KeyValueStoreSolver(spark, config=_default_cfg("NON_EXISTENT_PROJECT")) + query = key_value_store_db.query + tags_df = solver.filter_container_tags(spark, query) + result = solver.filter_container_metrics(spark, query, tags_df) + assert result.count() == 0 + + def test_pre_filtered_containers_df_short_circuits_read( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """``pre_filtered_containers_df`` should replace the table read.""" + solver = KeyValueStoreSolver(spark, config=_default_cfg()) + query = key_value_store_db.query + full = query.db.container_metrics(spark) + pre = full.where(F.col("container_id") == 1) + tags_df = solver.filter_container_tags(spark, query) + result = solver.filter_container_metrics( + spark, query, tags_df, pre_filtered_containers_df=pre + ) + container_ids = {row.container_id for row in result.collect()} + assert container_ids == {1} + + def test_all_container_metrics_columns_preserved( + self, spark: SparkSession, key_value_store_db: MeasurementDB + ): + """Result must keep all container_metrics columns (e.g. 
start_ts/stop_ts).""" + solver = KeyValueStoreSolver(spark, config=_default_cfg()) + query = key_value_store_db.query + tags_df = solver.filter_container_tags(spark, query) + result = solver.filter_container_metrics(spark, query, tags_df) + cols = set(result.columns) + # container_metrics CSV has start_ts/stop_ts plus dimensional columns + assert {"container_id", "start_ts", "stop_ts", "brand", "model"}.issubset(cols) + class TestKeyValueStoreSolverEmptySelector: """Tests for empty and edge-case TagSelector values in the KeyValueStoreSolver.""" diff --git a/tests/mda_reporting/integration/simple_report_test.py b/tests/mda_reporting/integration/simple_report_test.py index 9db17e4..8e09058 100644 --- a/tests/mda_reporting/integration/simple_report_test.py +++ b/tests/mda_reporting/integration/simple_report_test.py @@ -772,6 +772,9 @@ def test_simple_report_key_value_store(spark, key_value_store_db): solver_config=SolverConfig( project_id="SAMPLE_PROJECT", container_tags=TableConfig(column_name_mapping={"element_id": "key"}), + container_metrics=TableConfig( + column_name_mapping={"project": "project_id"}, + ), ), ), measurement_dimensions=[MeasurementDimensions.CONTAINER_ID], From 581018f9ebcc587c1b0dd8f2f3cdc7dc9ae29c3e Mon Sep 17 00:00:00 2001 From: "tom.bonfert" Date: Mon, 11 May 2026 23:12:52 +0200 Subject: [PATCH 2/2] Refactor KeyValueStoreSolver to remove redundant distinct call in broadcast join --- .../analyze/query/solvers/key_value_store_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mda_query_engine/analyze/query/solvers/key_value_store_solver.py b/src/mda_query_engine/analyze/query/solvers/key_value_store_solver.py index 9b500b9..db267a0 100644 --- a/src/mda_query_engine/analyze/query/solvers/key_value_store_solver.py +++ b/src/mda_query_engine/analyze/query/solvers/key_value_store_solver.py @@ -163,7 +163,7 @@ def filter_container_metrics( metrics = metrics.where(self._build_expr(metric_filters)) return metrics.join( - 
F.broadcast(container_df.select(container_id_col).distinct()), + F.broadcast(container_df.select(container_id_col)), on=container_id_col, how="inner", ).dropDuplicates([container_id_col])