From 3950abd1330bd7542d3c1af0a5533857b2c07c03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20B=C3=A5venstrand?= Date: Tue, 21 May 2024 15:04:35 +0200 Subject: [PATCH] feat(transformer): Update `ExpressionTransformer` to use `TypedDict` instead of tuples. --- .../transform/expression_transformer.py | 41 +++++++++++++------ .../transform/test_expression_transformer.py | 14 +++---- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/mleko/dataset/transform/expression_transformer.py b/mleko/dataset/transform/expression_transformer.py index 9d583fbc..16ad29f1 100644 --- a/mleko/dataset/transform/expression_transformer.py +++ b/mleko/dataset/transform/expression_transformer.py @@ -6,6 +6,7 @@ from typing import Hashable import vaex +from typing_extensions import TypedDict from mleko.cache.fingerprinters.json_fingerprinter import JsonFingerprinter from mleko.dataset.data_schema import DataSchema, DataType @@ -18,13 +19,26 @@ """A module-level logger for the module.""" +class ExpressionTransformerConfig(TypedDict): + """A type alias for the configuration of the expression transformer.""" + + expression: str + """The `vaex` expression used to create the new feature.""" + + type: DataType + """The data type of the new feature.""" + + is_meta: bool + """A boolean indicating if the new feature is a metadata feature.""" + + class ExpressionTransformer(BaseTransformer): """Creates new features using `vaex` expressions.""" @auto_repr def __init__( self, - expressions: dict[str, tuple[str, DataType, bool]], + expressions: dict[str, ExpressionTransformerConfig], cache_directory: str | Path = "data/expression-transformer", cache_size: int = 1, ) -> None: @@ -39,8 +53,9 @@ def __init__( For example, the expression of `df["a"] + df["b"]` can be extracted using `(df["a"] + df["b"]).expression`. Args: - expressions: A dictionary where the key is the name of the new feature and the value is a tuple containing - the expression, the data type and a boolean indicating if the feature is a metadata feature. In + expressions: A dictionary where the key is the name of the new feature and the value is a dictionary + containing the expression, the data type and a boolean indicating if the feaature is a metadata feature. + The expression must be a valid `vaex` expression that can be evaluated on the DataFrame. cache_directory: The directory where the cache will be stored locally. cache_size: The maximum number of cache entries to keep in the cache. @@ -49,9 +64,9 @@ def __init__( >>> from mleko.dataset.transform import ExpressionTransformer >>> transformer = ExpressionTransformer( ... expressions={ - ... "sum": ("a + b", "numerical", False), - ... "product": ("a * b", "numerical", False), - ... "both_positive": ("(a > 0) & (b > 0)", "boolean", True), + ... "sum": {"expression": "a + b", "type": "numerical", "is_meta": False}, + ... "product": {"expression": "a * b", "type": "numerical", "is_meta": False}, + ... "both_positive": {"expression": "(a > 0) & (b > 0)", "type": "boolean", "is_meta": True}, ... } ... ) >>> df = vaex.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]}) @@ -70,7 +85,7 @@ def __init__( def _fit( self, data_schema: DataSchema, dataframe: vaex.DataFrame - ) -> tuple[DataSchema, dict[str, tuple[str, DataType, bool]]]: + ) -> tuple[DataSchema, dict[str, ExpressionTransformerConfig]]: """No fitting is required for the expression transformer. Args: @@ -94,11 +109,13 @@ def _transform(self, data_schema: DataSchema, dataframe: vaex.DataFrame) -> tupl """ df = dataframe.copy() ds = data_schema.copy() - for feature, (expression, data_type, is_meta) in self._transformer.items(): - logger.info(f"Creating new {data_type!r} feature {feature!r} using expression {expression!r}.") - df[feature] = get_column(df, expression).as_arrow() - if not is_meta: - ds.add_feature(feature, data_type) + for feature, config in self._transformer.items(): + logger.info( + f"Creating new {config['type']!r} feature {feature!r} using expression {config['expression']!r}." + ) + df[feature] = get_column(df, config["expression"]).as_arrow() + if not config["is_meta"]: + ds.add_feature(feature, config["type"]) return ds, df def _fingerprint(self) -> Hashable: diff --git a/tests/dataset/transform/test_expression_transformer.py b/tests/dataset/transform/test_expression_transformer.py index 74841bba..c1d272b7 100644 --- a/tests/dataset/transform/test_expression_transformer.py +++ b/tests/dataset/transform/test_expression_transformer.py @@ -35,9 +35,9 @@ def test_expression_transformer( """Should correctly frequency encode the specified features.""" expression_transformer = ExpressionTransformer( { - "sum": ("astype(a + b + c, 'int32')", "numerical", False), - "product": ("astype(a * b * c, 'int32')", "numerical", False), - "all_positive": ("(a >= 0) & (b >= 0) & (c >= 0)", "boolean", True), + "sum": {"expression": "astype(a + b + c, 'int32')", "type": "numerical", "is_meta": False}, + "product": {"expression": "astype(a * b * c, 'int32')", "type": "numerical", "is_meta": False}, + "all_positive": {"expression": "(a >= 0) & (b >= 0) & (c >= 0)", "type": "boolean", "is_meta": True}, }, cache_directory=temporary_directory, ) @@ -57,8 +57,8 @@ def test_cache( """Should correctly frequency encode features and use cache if possible.""" ExpressionTransformer( { - "sum": ("astype(a + b + where(isna(c), 0, c), 'int32')", "numerical", False), - "product": ("astype(a * b * where(isna(c), 1, c), 'int32')", "numerical", False), + "sum": {"expression": "astype(a + b + c, 'int32')", "type": "numerical", "is_meta": False}, + "product": {"expression": "astype(a * b * c, 'int32')", "type": "numerical", "is_meta": False}, }, cache_directory=temporary_directory, ).fit_transform(example_data_schema, example_vaex_dataframe) @@ -66,8 +66,8 @@ def test_cache( with patch.object(ExpressionTransformer, "_fit_transform") as mocked_fit_transform: ExpressionTransformer( { - "sum": ("astype(a + b + where(isna(c), 0, c), 'int32')", "numerical", False), - "product": ("astype(a * b * where(isna(c), 1, c), 'int32')", "numerical", False), + "sum": {"expression": "astype(a + b + c, 'int32')", "type": "numerical", "is_meta": False}, + "product": {"expression": "astype(a * b * c, 'int32')", "type": "numerical", "is_meta": False}, }, cache_directory=temporary_directory, ).fit_transform(example_data_schema, example_vaex_dataframe)