Skip to content

Commit

Permalink
Tidy type annotations and remove unused test variables.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 49b8748190636a6a560534d7dab1bde937249073
  • Loading branch information
kboyd committed Jun 15, 2022
1 parent 93a6d05 commit 53e1881
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 19 deletions.
8 changes: 4 additions & 4 deletions src/gretel_synthetics/timeseries_dgan/dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
level=logging.INFO,
)

NumpyArrayPair = Tuple[np.ndarray, np.ndarray]
AttributeFeaturePair = Tuple[Optional[np.ndarray], np.ndarray]
NumpyArrayTriple = Tuple[np.ndarray, np.ndarray, np.ndarray]


Expand Down Expand Up @@ -276,7 +276,7 @@ def generate_numpy(
n: Optional[int] = None,
attribute_noise: Optional[torch.Tensor] = None,
feature_noise: Optional[torch.Tensor] = None,
) -> NumpyArrayPair:
) -> AttributeFeaturePair:
"""Generate synthetic data from DGAN model.
Once trained, a DGAN model can generate arbitrary amounts of
Expand Down Expand Up @@ -647,7 +647,7 @@ def _train(

def _generate(
self, attribute_noise: torch.Tensor, feature_noise: torch.Tensor
) -> Union[NumpyArrayPair, NumpyArrayTriple]:
) -> NumpyArrayTriple:
"""Internal method for generating from a DGAN model.
Returns data in the internal representation, including additional
Expand Down Expand Up @@ -800,7 +800,7 @@ def _extract_from_dataframe(
df: pd.DataFrame,
attribute_columns: Optional[List[Union[str, int]]] = None,
feature_columns: Optional[List[Union[str, int]]] = None,
) -> NumpyArrayPair:
) -> AttributeFeaturePair:
"""Extract attribute and feature arrays from a single pandas DataFrame
Note this method only supports time series of 1 variable where the time
Expand Down
5 changes: 1 addition & 4 deletions src/gretel_synthetics/timeseries_dgan/torch_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,7 @@ def _make_attribute_generator(

def forward(
self, attribute_noise: torch.Tensor, feature_noise: torch.Tensor
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Apply module to input.
Args:
Expand Down
15 changes: 8 additions & 7 deletions src/gretel_synthetics/timeseries_dgan/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def create_outputs_from_data(
normalization: Normalization,
apply_feature_scaling: bool = False,
apply_example_scaling: bool = False,
) -> Tuple[List[Output], List[Output]]:
) -> Tuple[Optional[List[Output]], List[Output]]:
"""Create output metadata from data.
Args:
Expand Down Expand Up @@ -201,7 +201,7 @@ def rescale_inverse(


def transform(
original_data: np.ndarray,
original_data: Optional[np.ndarray],
outputs: List[Output],
variable_dim_index: int,
num_examples: Optional[int] = None,
Expand All @@ -217,7 +217,7 @@ def transform(
apply_example_scaling is True
Args:
original_data: data to transform, 2d or 3d numpy array
original_data: data to transform, 2d or 3d numpy array, or None
outputs: Output metadata for each variable
variable_dim_index: dimension of numpy array that contains the
variables, for 2d numpy arrays this should be 1, for 3d should be 2
Expand All @@ -231,9 +231,10 @@ def transform(
Internal representation of data. A single numpy array if the input was a
2d array or if no outputs have apply_example_scaling=True. A tuple of
features, additional_attributes is returned when transforming features
(a 3d numpy array) and example scaling is used. If the input data is a
nan-filled tensor, then a single numpy array filled with nan's that has
the first dimension shape of the number examples of the feature vector is returned.
(a 3d numpy array) and example scaling is used. If the input data is
None, then a single numpy array filled with nan's that has the first
dimension shape of the number examples of the feature vector is
returned.
"""
additional_attribute_parts = []
parts = []
Expand Down Expand Up @@ -342,7 +343,7 @@ def inverse_transform(
outputs: List[Output],
variable_dim_index: int,
additional_attributes: Optional[np.ndarray] = None,
) -> np.ndarray:
) -> Optional[np.ndarray]:
"""Invert transform to map back to original space.
Args:
Expand Down
5 changes: 1 addition & 4 deletions tests/timeseries_dgan/test_dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,11 @@ def test_train_numpy(
itertools.product([False, True], [False, True]),
)
def test_train_numpy_no_attributes_1(
attribute_data,
feature_data,
config: DGANConfig,
use_attribute_discriminator,
is_normalized,
):
attributes, attribute_types = attribute_data
features, feature_types = feature_data

config.use_attribute_discriminator = use_attribute_discriminator
Expand All @@ -220,6 +218,7 @@ def test_train_numpy_no_attributes_1(

attributes, features = dg.generate_numpy(18)

assert attributes == None
assert features.shape == (18, 20, 2)


Expand Down Expand Up @@ -434,7 +433,6 @@ def test_save_and_load(
itertools.product([False, True], [False, True], [10, 25], [2, 5]),
)
def test_save_and_load_no_attributes(
attribute_data,
feature_data,
config: DGANConfig,
tmp_path,
Expand All @@ -443,7 +441,6 @@ def test_save_and_load_no_attributes(
noise_dim,
sample_len,
):
attributes, attribute_types = attribute_data
features, feature_types = feature_data

config.epochs = 1
Expand Down

0 comments on commit 53e1881

Please sign in to comment.