Skip to content

Commit

Permalink
Fix _parallel_transform for ACTGAN.
Browse files: browse the repository at this point in the history
GitOrigin-RevId: 94cba19da0b91df443ddedf84523d467d32eecce
  • Loading branch information
pimlock committed Jan 18, 2023
1 parent 8889230 commit e248630
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
6 changes: 2 additions & 4 deletions src/gretel_synthetics/actgan/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
import warnings

from types import MethodType
from functools import partial
from typing import Any, Dict, FrozenSet, List, Optional, Sequence, Union

import numpy as np
Expand Down Expand Up @@ -249,9 +249,7 @@ def _patched_basen_to_integer(self, X, cols, base):

return X

basen_encoder.basen_to_integer = MethodType(
_patched_basen_to_integer, basen_encoder
)
basen_encoder.basen_to_integer = partial(_patched_basen_to_integer, basen_encoder)


class DataTransformer:
Expand Down
39 changes: 38 additions & 1 deletion tests/actgan/test_data_transformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import random

import numpy as np
import pandas as pd

from gretel_synthetics.actgan.data_transformer import BinaryEncodingTransformer
from gretel_synthetics.actgan.data_transformer import (
BinaryEncodingTransformer,
DataTransformer,
)


def test_basic_enc_dec():
Expand Down Expand Up @@ -66,3 +71,35 @@ def test_encoder_with_regex_metachars():
check = encoder.reverse_transform(transformed_sample)

assert check[col_name].equals(pd.Series(["A"]))


def test_parallel_transform():
    """Check that transform() goes through _parallel_transform and round-trips.

    A single-column frame of 700 random "A"/"B"/"C" values is fitted,
    transformed and inverse-transformed. The test passes when the
    transformer's ``_parallel_transform`` attribute was invoked during
    ``transform`` and the recovered frame equals the input frame.
    """
    values = [random.choice(["A", "B", "C"]) for _ in range(700)]
    df = pd.DataFrame(data=values, columns=["foo"])

    transformer = DataTransformer(binary_encoder_cutoff=2)
    transformer.fit(df, discrete_columns=["foo"])

    # Flipped by the spy below once the wrapped method is invoked.
    was_called = False

    def _spy_on(target: callable):
        """
        Return a plain function that records the call and delegates to target.
        NOTE: A mock is deliberately not used here, because that would fail
        pickling when creating parallel workers.
        """

        def spy(self, *args, **kwargs):
            # The spy is stored as an instance attribute, so it is not bound:
            # ``self`` here is simply the first argument supplied by the caller
            # and is forwarded unchanged to the wrapped callable.
            nonlocal was_called
            was_called = True
            return target(self, *args, **kwargs)

        return spy

    transformer._parallel_transform = _spy_on(transformer._parallel_transform)

    encoded = transformer.transform(df)
    decoded = transformer.inverse_transform(encoded)

    assert was_called is True
    pd.testing.assert_frame_equal(df, decoded)

0 comments on commit e248630

Please sign in to comment.