Skip to content

Commit

Permalink
Add recent(-ish) top-level packages to pyproject.toml config
Browse files Browse the repository at this point in the history
GitOrigin-RevId: f0cb727f9894f7a856ef78724536f453ba3af099
  • Loading branch information
misberner committed Apr 2, 2024
1 parent c69ffcf commit 1847781
Show file tree
Hide file tree
Showing 21 changed files with 63 additions and 39 deletions.
25 changes: 13 additions & 12 deletions src/gretel_synthetics/actgan/actgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@
import pandas as pd
import torch

from packaging import version
from torch import optim
from torch.nn import (
BatchNorm1d,
Dropout,
functional,
LeakyReLU,
Linear,
Module,
ReLU,
Sequential,
)

from gretel_synthetics.actgan.base import BaseSynthesizer, random_state
from gretel_synthetics.actgan.column_encodings import (
BinaryColumnEncoding,
Expand All @@ -21,18 +34,6 @@
)
from gretel_synthetics.actgan.train_data import TrainData
from gretel_synthetics.typing import DFLike
from packaging import version
from torch import optim
from torch.nn import (
BatchNorm1d,
Dropout,
functional,
LeakyReLU,
Linear,
Module,
ReLU,
Sequential,
)

logging.basicConfig()
logger = logging.getLogger(__name__)
Expand Down
8 changes: 5 additions & 3 deletions src/gretel_synthetics/actgan/actgan_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,23 @@
import numpy as np
import pandas as pd

from rdt.transformers import BaseTransformer
from sdv.tabular.base import BaseTabularModel

from gretel_synthetics.actgan.actgan import ACTGANSynthesizer
from gretel_synthetics.actgan.columnar_df import ColumnarDF
from gretel_synthetics.actgan.structures import ConditionalVectorType
from gretel_synthetics.detectors.sdv import SDVTableMetadata
from gretel_synthetics.utils import rdt_patches, torch_utils
from rdt.transformers import BaseTransformer
from sdv.tabular.base import BaseTabularModel

if TYPE_CHECKING:
from gretel_synthetics.actgan.structures import EpochInfo
from numpy.random import RandomState
from sdv.constraints import Constraint
from sdv.metadata import Metadata
from torch import Generator

from gretel_synthetics.actgan.structures import EpochInfo

EPOCH_CALLBACK = "epoch_callback"

logging.basicConfig()
Expand Down
3 changes: 2 additions & 1 deletion src/gretel_synthetics/actgan/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
import pandas as pd

from rdt.transformers import BinaryEncoder, OneHotEncoder

from gretel_synthetics.actgan.column_encodings import (
BinaryColumnEncoding,
FloatColumnEncoding,
Expand All @@ -22,7 +24,6 @@
ClusterBasedNormalizer,
)
from gretel_synthetics.typing import DFLike
from rdt.transformers import BinaryEncoder, OneHotEncoder

logging.basicConfig()
logger = logging.getLogger(__name__)
Expand Down
3 changes: 2 additions & 1 deletion src/gretel_synthetics/actgan/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
if TYPE_CHECKING:
import numpy as np

from gretel_synthetics.actgan.column_encodings import ColumnEncoding
from rdt.transformers.base import BaseTransformer

from gretel_synthetics.actgan.column_encodings import ColumnEncoding


class ColumnType(str, Enum):
CONTINUOUS = "continuous"
Expand Down
3 changes: 2 additions & 1 deletion src/gretel_synthetics/actgan/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
import pandas as pd

from category_encoders import BaseNEncoder, BinaryEncoder
from gretel_synthetics.typing import ListOrSeriesOrDF, SeriesOrDFLike
from rdt.transformers import BaseTransformer
from rdt.transformers import ClusterBasedNormalizer as RDTClusterBasedNormalizer
from rdt.transformers import FloatFormatter

from gretel_synthetics.typing import ListOrSeriesOrDF, SeriesOrDFLike

MODE = "mode"
VALID_ROUNDING_MODES = (MODE,)

Expand Down
8 changes: 5 additions & 3 deletions src/gretel_synthetics/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@
from typing import List, Optional, Tuple, Type, Union

import cloudpickle
import gretel_synthetics.const as const
import numpy as np
import pandas as pd

from pandas.errors import EmptyDataError
from tqdm.auto import tqdm

import gretel_synthetics.const as const

from gretel_synthetics.config import (
BaseConfig,
config_from_model_dir,
Expand All @@ -44,8 +48,6 @@
from gretel_synthetics.generate import generate_text, GenText, SeedingGenerator
from gretel_synthetics.tokenizers import BaseTokenizerTrainer
from gretel_synthetics.train import train
from pandas.errors import EmptyDataError
from tqdm.auto import tqdm

logging.basicConfig()
logger = logging.getLogger(__name__)
Expand Down
3 changes: 2 additions & 1 deletion src/gretel_synthetics/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from pathlib import Path
from typing import Callable, Optional, TYPE_CHECKING

import gretel_synthetics.const as const
import tensorflow as tf

import gretel_synthetics.const as const

from gretel_synthetics.tensorflow.generator import TensorFlowGenerator
from gretel_synthetics.tensorflow.train import train_rnn

Expand Down
3 changes: 2 additions & 1 deletion src/gretel_synthetics/detectors/sdv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import numpy as np
import pandas as pd

from gretel_synthetics.detectors.dates import detect_datetimes
from rdt.transformers import BaseTransformer
from rdt.transformers.datetime import UnixTimestampEncoder

from gretel_synthetics.detectors.dates import detect_datetimes

if TYPE_CHECKING:
from gretel_synthetics.detectors.dates import DateTimeColumn

Expand Down
3 changes: 2 additions & 1 deletion src/gretel_synthetics/generate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Callable, Optional, Union

from smart_open import open as smart_open

from gretel_synthetics.batch import DataFrameBatch, MAX_INVALID
from gretel_synthetics.config import config_from_model_dir
from gretel_synthetics.generate import generate_text
from gretel_synthetics.utils.tar_util import safe_extractall
from smart_open import open as smart_open

logging.basicConfig()
logger = logging.getLogger(__name__)
Expand Down
3 changes: 2 additions & 1 deletion src/gretel_synthetics/tensorflow/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import pandas as pd
import tensorflow as tf

from tqdm import tqdm

from gretel_synthetics.const import (
METRIC_ACCURACY,
METRIC_DELTA,
Expand All @@ -30,7 +32,6 @@
from gretel_synthetics.tensorflow.model import build_model, load_model
from gretel_synthetics.tokenizers import BaseTokenizer
from gretel_synthetics.train import EpochState
from tqdm import tqdm

if TYPE_CHECKING:
from gretel_synthetics.config import TensorFlowConfig
Expand Down
3 changes: 2 additions & 1 deletion src/gretel_synthetics/timeseries_dgan/dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
import pandas as pd
import torch

from torch.utils.data import DataLoader, Dataset, TensorDataset

from gretel_synthetics.errors import DataError, InternalError, ParameterError
from gretel_synthetics.timeseries_dgan.config import DfStyle, DGANConfig, OutputType
from gretel_synthetics.timeseries_dgan.structures import ProgressInfo
Expand All @@ -67,7 +69,6 @@
Output,
transform,
)
from torch.utils.data import DataLoader, Dataset, TensorDataset

logger = logging.getLogger(__name__)

Expand Down
3 changes: 2 additions & 1 deletion src/gretel_synthetics/timeseries_dgan/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import numpy as np

from category_encoders import BinaryEncoder, OneHotEncoder
from gretel_synthetics.timeseries_dgan.config import Normalization, OutputType
from scipy.stats import mode

from gretel_synthetics.timeseries_dgan.config import Normalization, OutputType


def _new_uuid() -> str:
"""Return a random uuid prefixed with 'gretel-'."""
Expand Down
6 changes: 4 additions & 2 deletions src/gretel_synthetics/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@
)

import cloudpickle
import gretel_synthetics.const as const
import numpy as np
import sentencepiece as spm

from gretel_synthetics.errors import ParameterError
from smart_open import open as smart_open

import gretel_synthetics.const as const

from gretel_synthetics.errors import ParameterError

if TYPE_CHECKING:
from gretel_synthetics.config import BaseConfig
else:
Expand Down
3 changes: 2 additions & 1 deletion tests-integration/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
import pandas as pd
import pytest

from smart_open import open as smart_open

from gretel_synthetics.batch import DataFrameBatch, GenerationProgress
from gretel_synthetics.generate_utils import DataFileGenerator
from gretel_synthetics.utils.tar_util import safe_extractall
from smart_open import open as smart_open

BATCH_MODELS = [
"https://gretel-public-website.s3-us-west-2.amazonaws.com/tests/synthetics/models/safecast-batch-sp-0-14.tar.gz",
Expand Down
3 changes: 2 additions & 1 deletion tests-integration/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

from pathlib import Path

import gretel_synthetics.const as const
import pandas as pd
import pytest

import gretel_synthetics.const as const

from gretel_synthetics.batch import DataFrameBatch, PATH_HOLDER
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.errors import DataError
Expand Down
3 changes: 2 additions & 1 deletion tests/actgan/test_actgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import pandas as pd
import pytest

from pandas.api.types import is_number

from gretel_synthetics.actgan import ACTGAN
from gretel_synthetics.actgan.data_transformer import BinaryEncodingTransformer
from gretel_synthetics.actgan.structures import ConditionalVectorType
from pandas.api.types import is_number


@pytest.fixture
Expand Down
5 changes: 3 additions & 2 deletions tests/detectors/test_sdv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import pandas as pd
import pytest

from gretel_synthetics.detectors.dates import DateTimeColumn, DateTimeColumns
from gretel_synthetics.detectors.sdv import EmptyFieldTransformer, SDVTableMetadata
from rdt import HyperTransformer
from rdt.transformers.datetime import UnixTimestampEncoder
from sdv import Table

from gretel_synthetics.detectors.dates import DateTimeColumn, DateTimeColumns
from gretel_synthetics.detectors.sdv import EmptyFieldTransformer, SDVTableMetadata


def _create_info() -> DateTimeColumns:
return DateTimeColumns(columns={"footime": DateTimeColumn("footime", "%Y-%m-%d")})
Expand Down
3 changes: 2 additions & 1 deletion tests/test_tokenizers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from copy import deepcopy
from pathlib import Path

import gretel_synthetics.tokenizers as tok
import pytest

import gretel_synthetics.tokenizers as tok

from gretel_synthetics.config import BaseConfig
from gretel_synthetics.tokenizers import VocabSizeTooSmall

Expand Down
5 changes: 3 additions & 2 deletions tests/timeseries_dgan/test_dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import pandas as pd
import pytest

from pandas.api.types import is_numeric_dtype, is_object_dtype
from pandas.testing import assert_frame_equal

from gretel_synthetics.errors import DataError, ParameterError
from gretel_synthetics.timeseries_dgan.config import (
DfStyle,
Expand All @@ -27,8 +30,6 @@
ContinuousOutput,
OneHotEncodedOutput,
)
from pandas.api.types import is_numeric_dtype, is_object_dtype
from pandas.testing import assert_frame_equal


@pytest.fixture
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_rdt_float_formatter_orig.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import pandas as pd
import pytest

from gretel_synthetics.utils.rdt_patches import patch_float_formatter_rounding_bug
from rdt.transformers.null import NullTransformer
from rdt.transformers.numerical import FloatFormatter

from gretel_synthetics.utils.rdt_patches import patch_float_formatter_rounding_bug

with patch_float_formatter_rounding_bug():
# This is the original suite of tests for the FloatFormatter from rdt.
# Source code is copied from
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_rdt_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import numpy as np
import pandas as pd

from rdt.transformers.numerical import FloatFormatter

from gretel_synthetics.utils.rdt_patches import (
_patched_float_formatter_reverse_transform,
patch_float_formatter_rounding_bug,
)
from rdt.transformers.numerical import FloatFormatter


def test_original_rounding_bug_upstream():
Expand Down

0 comments on commit 1847781

Please sign in to comment.