Skip to content

Commit

Permalink
Fix ZERO_ONE normalization bug in DGAN and add transformation tests.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 0d0a6f9229440a47d634188d7de78f0984abf290
  • Loading branch information
kboyd committed Apr 21, 2022
1 parent fd313c7 commit 5e78b94
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/gretel_synthetics/timeseries_dgan/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,19 @@ def rescale(
original: data in original space
normalization: output range for scaling, ZERO_ONE or MINUSONE_ONE
global_min: minimum to use for scaling, either a scalar or has same
dimension as original.shape[0] for scaling each time series
independently
shape as original
global_max: maximum to use for scaling, either a scalar or has same
dimension as original.shape[0]
shape as original
Returns:
Data in transformed space
"""

range = np.maximum(global_max - global_min, 1e-6)
if normalization == Normalization.ZERO_ONE:
return (original - global_min) / (global_max - global_min)
return (original - global_min) / range
elif normalization == Normalization.MINUSONE_ONE:
return (2.0 * (original - global_min) / (global_max - global_min + 1e-6)) - 1.0
return (2.0 * (original - global_min) / range) - 1.0


def rescale_inverse(
Expand All @@ -191,10 +192,11 @@ def rescale_inverse(
Returns:
Data in original space
"""
range = global_max - global_min
if normalization == Normalization.ZERO_ONE:
return transformed * (global_max - global_min) + global_min
return transformed * range + global_min
elif normalization == Normalization.MINUSONE_ONE:
return ((transformed + 1) / 2) * (global_max - global_min) + global_min
return ((transformed + 1.0) / 2.0) * range + global_min


def transform(
Expand All @@ -219,8 +221,8 @@ def transform(
Returns:
Internal representation of data. A single numpy array if the input was a
2d array or if no outputs have apply_example_scaling=True. A tuple of
additional_attributes, features is returned when transforming features
(a 3d numpy array) and example scaling is usd.
features, additional_attributes is returned when transforming features
(a 3d numpy array) and example scaling is used.
"""

additional_attribute_parts = []
Expand Down
118 changes: 118 additions & 0 deletions tests/timeseries_dgan/test_transformations.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from dataclasses import FrozenInstanceError

import numpy as np
import pytest

from gretel_synthetics.timeseries_dgan.config import Normalization
from gretel_synthetics.timeseries_dgan.transformations import (
ContinuousOutput,
DiscreteOutput,
inverse_transform,
rescale,
rescale_inverse,
transform,
)


Expand Down Expand Up @@ -39,3 +44,116 @@ def test_output():
apply_feature_scaling=True,
apply_example_scaling=True,
)


def test_rescale_and_inverse():
    """Check rescale values and the inverse round trip on a 1-d array.

    The scalar min and max are taken from the data itself, so the expected
    scaled values put the smallest element at the bottom of the target range
    and the largest element at the top.
    """
    original = np.array([1.5, 3, -1.0])
    global_min = np.min(original)
    global_max = np.max(original)

    # (normalization, expected scaled values), checked in this order.
    cases = (
        (Normalization.ZERO_ONE, [0.625, 1.0, 0.0]),
        (Normalization.MINUSONE_ONE, [0.25, 1.0, -1.0]),
    )
    for normalization, expected in cases:
        scaled = rescale(original, normalization, global_min, global_max)
        np.testing.assert_allclose(scaled, expected)

        restored = rescale_inverse(scaled, normalization, global_min, global_max)
        np.testing.assert_allclose(restored, original)


def test_rescale_and_inverse_by_example():
    """Check rescale and its inverse when each row has its own min and max.

    The per-row min/max are broadcast to the shape of the data before being
    passed in. The third row is constant (max == min); the expected values
    below place it at the bottom of the target range, and the inverse must
    still recover the original values.
    """
    original = np.array(
        [
            [1.5, 3, -1.0],
            [10, 20, 30],
            [1000, 1000, 1000.0],
            [-0.1, -0.3, -0.5],
        ]
    )

    # Row-wise extrema expanded to the same (4, 3) shape as the data.
    mins = np.broadcast_to(original.min(axis=1, keepdims=True), original.shape)
    maxes = np.broadcast_to(original.max(axis=1, keepdims=True), original.shape)

    cases = (
        (
            Normalization.ZERO_ONE,
            [
                [0.625, 1.0, 0.0],
                [0.0, 0.5, 1.0],
                [0.0, 0.0, 0.0],
                [1.0, 0.5, 0.0],
            ],
        ),
        (
            Normalization.MINUSONE_ONE,
            [
                [0.25, 1.0, -1.0],
                [-1.0, 0.0, 1.0],
                [-1.0, -1.0, -1.0],
                [1.0, 0.0, -1.0],
            ],
        ),
    )
    for normalization, expected in cases:
        scaled = rescale(original, normalization, mins, maxes)
        np.testing.assert_allclose(scaled, expected)

        restored = rescale_inverse(scaled, normalization, mins, maxes)
        np.testing.assert_allclose(restored, original)


@pytest.mark.parametrize(
    "normalization", [Normalization.ZERO_ONE, Normalization.MINUSONE_ONE]
)
def test_transform_and_inverse_attributes(normalization):
    """Round-trip a 2-d attribute array through transform/inverse_transform.

    Columns alternate between continuous and discrete variables. The last
    continuous column is constant with global_min == global_max, which
    exercises the zero-range case in the scaling code.

    Args:
        normalization: Normalization applied to every continuous output.
    """
    n = 100
    # Local seeded generator: the data is identical on every run, so any
    # failure is reproducible, and the global NumPy RNG state is untouched.
    rng = np.random.default_rng(0)
    attributes = np.stack(
        (
            rng.random(n) * 1000.0 + 500.0,
            rng.integers(0, 2, size=n),
            rng.random(n) * 0.1 - 10.0,
            rng.integers(0, 5, size=n),
            np.full(n, 2.0),
        ),
        axis=1,
    )

    outputs = [
        ContinuousOutput("a", normalization, 500.0, 1500.0, True, False),
        DiscreteOutput("b", 2),
        ContinuousOutput("c", normalization, -10.0, -9.9, True, False),
        DiscreteOutput("d", 5),
        ContinuousOutput("e", normalization, 2.0, 2.0, True, False),
    ]

    transformed = transform(attributes, outputs, 1)
    # 10 = 1 + 2 + 1 + 5 + 1 columns; presumably each discrete output expands
    # to one column per category -- confirm against transform().
    assert transformed.shape == (n, 10)

    inversed = inverse_transform(transformed, outputs, 1)
    np.testing.assert_allclose(inversed, attributes)


@pytest.mark.parametrize(
    "normalization", [Normalization.ZERO_ONE, Normalization.MINUSONE_ONE]
)
def test_transform_and_inverse_features(normalization):
    """Round-trip a 3-d feature array through transform/inverse_transform.

    Both continuous outputs are built with the last two flags set to True, so
    transform returns a (features, additional_attributes) tuple, and that
    second array is passed back to inverse_transform to recover the input.

    Args:
        normalization: Normalization applied to every continuous output.
    """
    n = 100
    # Local seeded generator: the data is identical on every run, so any
    # failure is reproducible, and the global NumPy RNG state is untouched.
    rng = np.random.default_rng(0)
    features = np.stack(
        (
            rng.random((n, 10)) * 1000.0 + 500.0,
            rng.random((n, 10)) * 5.0,
            rng.integers(0, 3, size=(n, 10)),
        ),
        axis=2,
    )
    assert features.shape == (100, 10, 3)

    outputs = [
        ContinuousOutput("a", normalization, 500.0, 1500.0, True, True),
        ContinuousOutput("b", normalization, 0.0, 5.0, True, True),
        DiscreteOutput("c", 3),
    ]

    transformed, additional_attributes = transform(features, outputs, 2)
    # 5 = 1 + 1 + 3 feature columns; presumably the discrete output expands to
    # one column per category -- confirm against transform().
    assert transformed.shape == (100, 10, 5)
    # 4 columns for 2 example-scaled outputs; presumably two per-example
    # scaling values per output -- confirm against transform().
    assert additional_attributes.shape == (100, 4)

    inversed = inverse_transform(transformed, outputs, 2, additional_attributes)
    np.testing.assert_allclose(inversed, features)

0 comments on commit 5e78b94

Please sign in to comment.