feat: add left_join_transformer (#29)
premsrii committed Jan 5, 2023
1 parent 53684e9 commit 31fbde0
Showing 5 changed files with 131 additions and 21 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
@@ -88,7 +88,8 @@ repos:
       - id: poetry-export
         args:
           [
-            "--dev",
+            "--with",
+            "dev",
             "--format",
             "requirements.txt",
             "--output",
1 change: 1 addition & 0 deletions docs/README.md
@@ -53,6 +53,7 @@ poetry install
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`ColumnDropperTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.ColumnDropperTransformer)|Drops columns from a dataframe using Pandas drop method.|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`DtypeTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.DtypeTransformer)|Transformer that converts a column to a different dtype.|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`FunctionsTransformer`]( https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.FunctionsTransformer)|This transformer is a plain wrapper around the [sklearn.preprocessing.FunctionTransformer](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.FunctionTransformer.html).|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`LeftJoinTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.LeftJoinTransformer)|Uses Pandas merge function to perform a left-join based on the column of a dataframe and the index of another dataframe. The right dataframe is essentially a lookup table.|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`MapTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.MapTransformer)|This transformer iterates over all columns in the `features` list and applies the given callback to the column. For this it uses the `pandas.Series.map` method.
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`NaNTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.NaNTransformer)|Replace NaN values with a specified value. Internally Pandas fillna method is used.|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`QueryTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.QueryTransformer)|Applies a list of queries to a dataframe. If it operates on a dataset used for supervised learning this transformer should be applied on the dataframe containing `X` and `y`.
57 changes: 37 additions & 20 deletions requirements.txt
@@ -391,26 +391,43 @@ jupyterlab-widgets==3.0.5 ; python_version >= "3.8" and python_version < "3.11"
langcodes==3.3.0 ; python_version >= "3.8" and python_version < "3.11" \
--hash=sha256:4d89fc9acb6e9c8fdef70bcdf376113a3db09b67285d9e1d534de6d8818e7e69 \
--hash=sha256:794d07d5a28781231ac335a1561b8442f8648ca07cd518310aeb45d6f0807ef6
lazy-object-proxy==1.8.0 ; python_version >= "3.8" and python_version < "3.11" \
--hash=sha256:0c1c7c0433154bb7c54185714c6929acc0ba04ee1b167314a779b9025517eada \
--hash=sha256:14010b49a2f56ec4943b6cf925f597b534ee2fe1f0738c84b3bce0c1a11ff10d \
--hash=sha256:4e2d9f764f1befd8bdc97673261b8bb888764dfdbd7a4d8f55e4fbcabb8c3fb7 \
--hash=sha256:4fd031589121ad46e293629b39604031d354043bb5cdf83da4e93c2d7f3389fe \
--hash=sha256:5b51d6f3bfeb289dfd4e95de2ecd464cd51982fe6f00e2be1d0bf94864d58acd \
--hash=sha256:6850e4aeca6d0df35bb06e05c8b934ff7c533734eb51d0ceb2d63696f1e6030c \
--hash=sha256:6f593f26c470a379cf7f5bc6db6b5f1722353e7bf937b8d0d0b3fba911998858 \
--hash=sha256:71d9ae8a82203511a6f60ca5a1b9f8ad201cac0fc75038b2dc5fa519589c9288 \
--hash=sha256:7e1561626c49cb394268edd00501b289053a652ed762c58e1081224c8d881cec \
--hash=sha256:8f6ce2118a90efa7f62dd38c7dbfffd42f468b180287b748626293bf12ed468f \
--hash=sha256:ae032743794fba4d171b5b67310d69176287b5bf82a21f588282406a79498891 \
--hash=sha256:afcaa24e48bb23b3be31e329deb3f1858f1f1df86aea3d70cb5c8578bfe5261c \
--hash=sha256:b70d6e7a332eb0217e7872a73926ad4fdc14f846e85ad6749ad111084e76df25 \
--hash=sha256:c219a00245af0f6fa4e95901ed28044544f50152840c5b6a3e7b2568db34d156 \
--hash=sha256:ce58b2b3734c73e68f0e30e4e725264d4d6be95818ec0a0be4bb6bf9a7e79aa8 \
--hash=sha256:d176f392dbbdaacccf15919c77f526edf11a34aece58b55ab58539807b85436f \
--hash=sha256:e20bfa6db17a39c706d24f82df8352488d2943a3b7ce7d4c22579cb89ca8896e \
--hash=sha256:eac3a9a5ef13b332c059772fd40b4b1c3d45a3a2b05e33a361dee48e54a4dad0 \
--hash=sha256:eb329f8d8145379bf5dbe722182410fe8863d186e51bf034d2075eb8d85ee25b
lazy-object-proxy==1.9.0 ; python_version >= "3.8" and python_version < "3.11" \
--hash=sha256:09763491ce220c0299688940f8dc2c5d05fd1f45af1e42e636b2e8b2303e4382 \
--hash=sha256:0a891e4e41b54fd5b8313b96399f8b0e173bbbfc03c7631f01efbe29bb0bcf82 \
--hash=sha256:189bbd5d41ae7a498397287c408617fe5c48633e7755287b21d741f7db2706a9 \
--hash=sha256:18b78ec83edbbeb69efdc0e9c1cb41a3b1b1ed11ddd8ded602464c3fc6020494 \
--hash=sha256:1aa3de4088c89a1b69f8ec0dcc169aa725b0ff017899ac568fe44ddc1396df46 \
--hash=sha256:212774e4dfa851e74d393a2370871e174d7ff0ebc980907723bb67d25c8a7c30 \
--hash=sha256:2d0daa332786cf3bb49e10dc6a17a52f6a8f9601b4cf5c295a4f85854d61de63 \
--hash=sha256:5f83ac4d83ef0ab017683d715ed356e30dd48a93746309c8f3517e1287523ef4 \
--hash=sha256:659fb5809fa4629b8a1ac5106f669cfc7bef26fbb389dda53b3e010d1ac4ebae \
--hash=sha256:660c94ea760b3ce47d1855a30984c78327500493d396eac4dfd8bd82041b22be \
--hash=sha256:66a3de4a3ec06cd8af3f61b8e1ec67614fbb7c995d02fa224813cb7afefee701 \
--hash=sha256:721532711daa7db0d8b779b0bb0318fa87af1c10d7fe5e52ef30f8eff254d0cd \
--hash=sha256:7322c3d6f1766d4ef1e51a465f47955f1e8123caee67dd641e67d539a534d006 \
--hash=sha256:79a31b086e7e68b24b99b23d57723ef7e2c6d81ed21007b6281ebcd1688acb0a \
--hash=sha256:81fc4d08b062b535d95c9ea70dbe8a335c45c04029878e62d744bdced5141586 \
--hash=sha256:8fa02eaab317b1e9e03f69aab1f91e120e7899b392c4fc19807a8278a07a97e8 \
--hash=sha256:9090d8e53235aa280fc9239a86ae3ea8ac58eff66a705fa6aa2ec4968b95c821 \
--hash=sha256:946d27deaff6cf8452ed0dba83ba38839a87f4f7a9732e8f9fd4107b21e6ff07 \
--hash=sha256:9990d8e71b9f6488e91ad25f322898c136b008d87bf852ff65391b004da5e17b \
--hash=sha256:9cd077f3d04a58e83d04b20e334f678c2b0ff9879b9375ed107d5d07ff160171 \
--hash=sha256:9e7551208b2aded9c1447453ee366f1c4070602b3d932ace044715d89666899b \
--hash=sha256:9f5fa4a61ce2438267163891961cfd5e32ec97a2c444e5b842d574251ade27d2 \
--hash=sha256:b40387277b0ed2d0602b8293b94d7257e17d1479e257b4de114ea11a8cb7f2d7 \
--hash=sha256:bfb38f9ffb53b942f2b5954e0f610f1e721ccebe9cce9025a38c8ccf4a5183a4 \
--hash=sha256:cbf9b082426036e19c6924a9ce90c740a9861e2bdc27a4834fd0a910742ac1e8 \
--hash=sha256:d9e25ef10a39e8afe59a5c348a4dbf29b4868ab76269f81ce1674494e2565a6e \
--hash=sha256:db1c1722726f47e10e0b5fdbf15ac3b8adb58c091d12b3ab713965795036985f \
--hash=sha256:e7c21c95cae3c05c14aafffe2865bbd5e377cfc1348c4f7751d9dc9a48ca4bda \
--hash=sha256:e8c6cfb338b133fbdbc5cfaa10fe3c6aeea827db80c978dbd13bc9dd8526b7d4 \
--hash=sha256:ea806fd4c37bf7e7ad82537b0757999264d5f70c45468447bb2b91afdbe73a6e \
--hash=sha256:edd20c5a55acb67c7ed471fa2b5fb66cb17f61430b7a6b9c3b4a1e40293b1671 \
--hash=sha256:f0117049dd1d5635bbff65444496c90e0baa48ea405125c088e93d9cf4525b11 \
--hash=sha256:f0705c376533ed2a9e5e97aacdbfe04cecd71e0aa84c7c0595d02ef93b6e4455 \
--hash=sha256:f12ad7126ae0c98d601a7ee504c1122bcef553d1d5e0c3bfa77b16b3968d2734 \
--hash=sha256:f2457189d8257dd41ae9b434ba33298aec198e30adf2dcdaaa3a28b9994f6adb \
--hash=sha256:f699ac1c768270c9e384e4cbd268d6e67aebcfae6cd623b4d7c3bfde5a35db59
locket==1.0.0 ; python_version >= "3.8" and python_version < "3.11" \
--hash=sha256:5c0d4c052a8bbbf750e056a8e65ccd309086f4f0f18a2eac306a8dfa4112a632 \
--hash=sha256:b6c819a722f7b6bd955b80781788e4a66a55628b858d347536b7e81325a3a5e3
67 changes: 67 additions & 0 deletions src/sk_transformers/generic_transformer.py
@@ -527,3 +527,70 @@ def __check_for_regex(value: Any) -> bool:
        except re.error:
            is_valid = False
        return is_valid


class LeftJoinTransformer(BaseTransformer):
    """
    Performs a database-style left-join using `pd.merge`. This transformer is suitable for
    looking up values for a column of a dataframe in another `pd.DataFrame`
    or `pd.Series`. Note that the join is based on the index of the right dataframe.

    Example:
    ```python
    import pandas as pd
    from sk_transformers.generic_transformer import LeftJoinTransformer

    X = pd.DataFrame({"foo": ["A", "B", "C", "A", "C"]})
    lookup_df = pd.Series([1, 2, 3], index=["A", "B", "C"], name="values")
    transformer = LeftJoinTransformer([("foo", lookup_df)])
    transformer.fit_transform(X)
    ```
    ```
      foo  foo_values
    0   A           1
    1   B           2
    2   C           3
    3   A           1
    4   C           3
    ```

    Args:
        features (List[Tuple[str, Union[pd.Series, pd.DataFrame]]]): A list of tuples
            where the first element is the name of the column
            and the second element is the look-up dataframe or series.
    """

    def __init__(
        self, features: List[Tuple[str, Union[pd.Series, pd.DataFrame]]]
    ) -> None:
        super().__init__()
        self.features = features

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        """
        Perform a left-join on the given columns of a dataframe with another corresponding dataframe.

        Args:
            X (pd.DataFrame): Dataframe containing the columns to be joined on.

        Returns:
            pd.DataFrame: Dataframe joined on the given columns.
        """

        X = check_ready_to_transform(self, X, [feature[0] for feature in self.features])

        for column, lookup_df in self.features:
            lookup_df = LeftJoinTransformer.__prefix_df_column_names(lookup_df, column)
            X = pd.merge(X, lookup_df, how="left", left_on=column, right_index=True)

        return X

    @staticmethod
    def __prefix_df_column_names(
        df: Union[pd.Series, pd.DataFrame], prefix: str
    ) -> Union[pd.Series, pd.DataFrame]:
        # Return renamed copies so the lookup objects stored in `self.features`
        # are not mutated; renaming in place would stack the prefix again on
        # every call to `transform`.
        if isinstance(df, pd.Series):
            return df.rename(prefix + "_" + (df.name if df.name else "lookup"))
        if isinstance(df, pd.DataFrame):
            return df.add_prefix(prefix + "_")
        return df
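
For readers who want to see the join in isolation, the following is a minimal sketch of what one `(column, lookup)` pair amounts to in plain pandas. It is not the library's code: the helper name `prefixed` is hypothetical, and it uses `Series.rename` and `DataFrame.add_prefix` (which return new objects) on the assumption that leaving the caller's lookup untouched is desirable.

```python
import pandas as pd


def prefixed(lookup, prefix):
    """Hypothetical helper: return a renamed copy of the lookup object."""
    if isinstance(lookup, pd.Series):
        # Series.rename with a scalar returns a new Series with that name.
        return lookup.rename(f"{prefix}_{lookup.name if lookup.name else 'lookup'}")
    # DataFrame.add_prefix returns a new frame with prefixed column labels.
    return lookup.add_prefix(f"{prefix}_")


X = pd.DataFrame({"foo": ["A", "B", "C", "A", "C"]})
lookup = pd.Series([1, 2, 3], index=["A", "B", "C"], name="values")

# Left join: every row of X is kept; values come from the lookup's index.
joined = pd.merge(
    X, prefixed(lookup, "foo"), how="left", left_on="foo", right_index=True
)
print(joined)
```

Because the helper returns a copy, `lookup.name` is still `"values"` afterwards, so running the join a second time yields the same `foo_values` column name.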
24 changes: 24 additions & 0 deletions tests/test_transformer/test_generic_transformer.py
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.pipeline import make_pipeline

@@ -7,6 +8,7 @@
    ColumnDropperTransformer,
    DtypeTransformer,
    FunctionsTransformer,
    LeftJoinTransformer,
    MapTransformer,
    NaNTransformer,
    QueryTransformer,
@@ -241,3 +243,25 @@ def test_nan_transform_in_pipeline(X_nan_values) -> None:
    assert X["c"][6] == "missing"
    assert pipeline.steps[0][0] == "nantransformer"
    assert pipeline.steps[0][1].features[0][1] == -1


def test_left_join_transformer_in_pipeline_for_series(X_categorical) -> None:
    lookup_df = pd.Series([1, 2], index=["A1", "A2"], name="values")
    pipeline = make_pipeline(LeftJoinTransformer([("a", lookup_df)]))
    result = pipeline.fit_transform(X_categorical)
    expected = np.array([1, 2, 2, 1, 1, 2, 1, 1])

    assert "a_values" in result.columns
    assert np.array_equal(result["a_values"].to_numpy(), expected)
    assert pipeline.steps[0][0] == "leftjointransformer"


def test_left_join_transformer_in_pipeline_for_dataframe(X_categorical) -> None:
    lookup_df = pd.DataFrame({"values": [1, 2]}, index=["A1", "A2"])
    pipeline = make_pipeline(LeftJoinTransformer([("a", lookup_df)]))
    result = pipeline.fit_transform(X_categorical)
    expected = np.array([1, 2, 2, 1, 1, 2, 1, 1])

    assert "a_values" in result.columns
    assert np.array_equal(result["a_values"].to_numpy(), expected)
    assert pipeline.steps[0][0] == "leftjointransformer"
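
Both tests above appear to use only keys that exist in the lookup (the expected arrays contain no missing values). A case that may be worth covering as well is a key with no match. The sketch below illustrates the expected left-join semantics with `pd.merge` directly rather than with the transformer or the `X_categorical` fixture, so the frame, the key `"A3"` and the test name are all assumptions made for illustration.

```python
import numpy as np
import pandas as pd


def test_left_join_keeps_unmatched_rows() -> None:
    # "A3" is deliberately absent from the lookup index.
    X = pd.DataFrame({"a": ["A1", "A2", "A3"]})
    lookup = pd.Series([1, 2], index=["A1", "A2"], name="a_values")

    result = pd.merge(X, lookup, how="left", left_on="a", right_index=True)

    # A left join never drops rows from the left frame.
    assert len(result) == len(X)
    # Matched keys get their looked-up value (upcast to float because of the NaN).
    assert result["a_values"].tolist()[:2] == [1.0, 2.0]
    # Unmatched keys are filled with NaN rather than raising.
    assert np.isnan(result["a_values"].iloc[2])


test_left_join_keeps_unmatched_rows()
```

If missing values are not acceptable downstream, a `NaNTransformer` step after the join would be one way to fill them.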
