Skip to content

Commit

Permalink
Pull Request (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBavenstrand committed May 21, 2024
2 parents 2ea8422 + 9ea8583 commit 1974448
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 19 deletions.
41 changes: 29 additions & 12 deletions mleko/dataset/transform/expression_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Hashable

import vaex
from typing_extensions import TypedDict

from mleko.cache.fingerprinters.json_fingerprinter import JsonFingerprinter
from mleko.dataset.data_schema import DataSchema, DataType
Expand All @@ -18,13 +19,26 @@
"""A module-level logger for the module."""


class ExpressionTransformerConfig(TypedDict):
    """A typed dictionary holding the configuration of a single new feature for the expression transformer."""

    expression: str
    """The `vaex` expression used to create the new feature."""

    type: DataType
    """The data type of the new feature."""

    # When this is False the transformer registers the feature in the data schema;
    # metadata features are added to the DataFrame only.
    is_meta: bool
    """A boolean indicating if the new feature is a metadata feature."""


class ExpressionTransformer(BaseTransformer):
"""Creates new features using `vaex` expressions."""

@auto_repr
def __init__(
self,
expressions: dict[str, tuple[str, DataType, bool]],
expressions: dict[str, ExpressionTransformerConfig],
cache_directory: str | Path = "data/expression-transformer",
cache_size: int = 1,
) -> None:
Expand All @@ -39,8 +53,9 @@ def __init__(
For example, the expression of `df["a"] + df["b"]` can be extracted using `(df["a"] + df["b"]).expression`.
Args:
expressions: A dictionary where the key is the name of the new feature and the value is a tuple containing
the expression, the data type and a boolean indicating if the feature is a metadata feature. In
expressions: A dictionary where the key is the name of the new feature and the value is a dictionary
containing the expression, the data type and a boolean indicating if the feature is a metadata feature.
The expression must be a valid `vaex` expression that can be evaluated on the DataFrame.
cache_directory: The directory where the cache will be stored locally.
cache_size: The maximum number of cache entries to keep in the cache.
Expand All @@ -49,9 +64,9 @@ def __init__(
>>> from mleko.dataset.transform import ExpressionTransformer
>>> transformer = ExpressionTransformer(
... expressions={
... "sum": ("a + b", "numerical", False),
... "product": ("a * b", "numerical", False),
... "both_positive": ("(a > 0) & (b > 0)", "boolean", True),
... "sum": {"expression": "a + b", "type": "numerical", "is_meta": False},
... "product": {"expression": "a * b", "type": "numerical", "is_meta": False},
... "both_positive": {"expression": "(a > 0) & (b > 0)", "type": "boolean", "is_meta": True},
... }
... )
>>> df = vaex.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]})
Expand All @@ -70,7 +85,7 @@ def __init__(

def _fit(
self, data_schema: DataSchema, dataframe: vaex.DataFrame
) -> tuple[DataSchema, dict[str, tuple[str, DataType, bool]]]:
) -> tuple[DataSchema, dict[str, ExpressionTransformerConfig]]:
"""No fitting is required for the expression transformer.
Args:
Expand All @@ -94,11 +109,13 @@ def _transform(self, data_schema: DataSchema, dataframe: vaex.DataFrame) -> tupl
"""
df = dataframe.copy()
ds = data_schema.copy()
for feature, (expression, data_type, is_meta) in self._transformer.items():
logger.info(f"Creating new {data_type!r} feature {feature!r} using expression {expression!r}.")
df[feature] = get_column(df, expression).as_arrow()
if not is_meta:
ds.add_feature(feature, data_type)
for feature, config in self._transformer.items():
logger.info(
f"Creating new {config['type']!r} feature {feature!r} using expression {config['expression']!r}."
)
df[feature] = get_column(df, config["expression"]).as_arrow()
if not config["is_meta"]:
ds.add_feature(feature, config["type"])
return ds, df

def _fingerprint(self) -> Hashable:
Expand Down
14 changes: 7 additions & 7 deletions tests/dataset/transform/test_expression_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def test_expression_transformer(
"""Should correctly frequency encode the specified features."""
expression_transformer = ExpressionTransformer(
{
"sum": ("astype(a + b + c, 'int32')", "numerical", False),
"product": ("astype(a * b * c, 'int32')", "numerical", False),
"all_positive": ("(a >= 0) & (b >= 0) & (c >= 0)", "boolean", True),
"sum": {"expression": "astype(a + b + c, 'int32')", "type": "numerical", "is_meta": False},
"product": {"expression": "astype(a * b * c, 'int32')", "type": "numerical", "is_meta": False},
"all_positive": {"expression": "(a >= 0) & (b >= 0) & (c >= 0)", "type": "boolean", "is_meta": True},
},
cache_directory=temporary_directory,
)
Expand All @@ -57,17 +57,17 @@ def test_cache(
"""Should correctly frequency encode features and use cache if possible."""
ExpressionTransformer(
{
"sum": ("astype(a + b + where(isna(c), 0, c), 'int32')", "numerical", False),
"product": ("astype(a * b * where(isna(c), 1, c), 'int32')", "numerical", False),
"sum": {"expression": "astype(a + b + c, 'int32')", "type": "numerical", "is_meta": False},
"product": {"expression": "astype(a * b * c, 'int32')", "type": "numerical", "is_meta": False},
},
cache_directory=temporary_directory,
).fit_transform(example_data_schema, example_vaex_dataframe)

with patch.object(ExpressionTransformer, "_fit_transform") as mocked_fit_transform:
ExpressionTransformer(
{
"sum": ("astype(a + b + where(isna(c), 0, c), 'int32')", "numerical", False),
"product": ("astype(a * b * where(isna(c), 1, c), 'int32')", "numerical", False),
"sum": {"expression": "astype(a + b + c, 'int32')", "type": "numerical", "is_meta": False},
"product": {"expression": "astype(a * b * c, 'int32')", "type": "numerical", "is_meta": False},
},
cache_directory=temporary_directory,
).fit_transform(example_data_schema, example_vaex_dataframe)
Expand Down

0 comments on commit 1974448

Please sign in to comment.