Skip to content

Commit

Permalink
PROD-301: Fix BinaryEncoder for columns with regex metachars.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: b1eab1687c6668b1d39d091e534d22148d404a9e
  • Loading branch information
pimlock authored and drew committed Jan 13, 2023
1 parent 13ed083 commit 8889230
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
42 changes: 41 additions & 1 deletion src/gretel_synthetics/actgan/data_transformer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import logging
import re
import uuid
import warnings

from types import MethodType
from typing import Any, Dict, FrozenSet, List, Optional, Sequence, Union

import numpy as np
import pandas as pd

from category_encoders import BinaryEncoder
from category_encoders import BaseNEncoder, BinaryEncoder
from gretel_synthetics.actgan.structures import (
ActivationFn,
ColumnIdInfo,
Expand Down Expand Up @@ -127,6 +129,8 @@ def _fit(self, data: SeriesOrDFLike) -> None:
data: Data to fit the transformer to.
"""
self.encoder = BinaryEncoder()
_patch_basen_to_integer(self.encoder.base_n_encoder)

data = self._prepare_data(data)
if not isinstance(data, pd.Series):
data = pd.Series(data)
Expand Down Expand Up @@ -214,6 +218,42 @@ def _reverse_transform(self, data: SeriesOrDFLike) -> SeriesOrDFLike:
return transformed_data


def _patch_basen_to_integer(basen_encoder: BaseNEncoder) -> None:
"""
FIXME(PROD-309): Temporary patch for https://github.com/scikit-learn-contrib/category_encoders/issues/392
"""

def _patched_basen_to_integer(self, X, cols, base):
"""
Copied from https://github.com/scikit-learn-contrib/category_encoders/blob/1def42827df4a9404553f41255878c45d754b1a0/category_encoders/basen.py#L266-L281
and applied this fix: https://github.com/scikit-learn-contrib/category_encoders/pull/393/files
"""
out_cols = X.columns.values.tolist()

for col in cols:
col_list = [
col0
for col0 in out_cols
if re.match(re.escape(str(col)) + "_\\d+", str(col0))
]
insert_at = out_cols.index(col_list[0])

if base == 1:
value_array = np.array([int(col0.split("_")[-1]) for col0 in col_list])
else:
len0 = len(col_list)
value_array = np.array([base ** (len0 - 1 - i) for i in range(len0)])
X.insert(insert_at, col, np.dot(X[col_list].values, value_array.T))
X.drop(col_list, axis=1, inplace=True)
out_cols = X.columns.values.tolist()

return X

basen_encoder.basen_to_integer = MethodType(
_patched_basen_to_integer, basen_encoder
)


class DataTransformer:
"""Data Transformer.
Expand Down
20 changes: 20 additions & 0 deletions tests/actgan/test_data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,23 @@ def test_enc_dec_with_mode():
df = pd.DataFrame(data=[[1, 1, 0]], columns=columns)
check = encoder.reverse_transform(df)
assert list(check["foo"])[0] == "A"


def test_encoder_with_regex_metachars():
"""
FIXME(PROD-309): This is a test covering this `category_encoders` issue:
https://github.com/scikit-learn-contrib/category_encoders/issues/392
We need our own fix for now, which can be removed once we migrate to
version with the fix upstream.
"""
col_name = "column*+{} (keep it secret!) [ab12-x]"

df = pd.DataFrame(data={col_name: ["A", "A", "B", "C", "D"]})
encoder = BinaryEncodingTransformer()
transformed = encoder.fit_transform(df, col_name)

transformed_sample = transformed.head(1)
check = encoder.reverse_transform(transformed_sample)

assert check[col_name].equals(pd.Series(["A"]))

0 comments on commit 8889230

Please sign in to comment.