RDS-132: Add DoppelGANger model to gretel_synthetics.
GitOrigin-RevId: 7b42541dad2fb2d1af3a3bf64dde835eaa347b0b
kboyd committed Apr 15, 2022
1 parent 9eacab5 commit 28a2f92
Showing 12 changed files with 2,134 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/index.rst
@@ -19,6 +19,7 @@ Modules
api/generate.rst
api/batch.rst
utils/index.rst
models/timeseries_dgan.rst


Indices and tables
34 changes: 34 additions & 0 deletions docs/models/timeseries_dgan.rst
@@ -0,0 +1,34 @@
Timeseries DGAN
===============

The Timeseries DGAN module contains a PyTorch implementation of the
DoppelGANger model; see https://arxiv.org/abs/1909.13403 for a detailed
description of the model.

.. code-block:: python

    import numpy as np

    from gretel_synthetics.timeseries_dgan.dgan import DGAN
    from gretel_synthetics.timeseries_dgan.config import DGANConfig

    # 10,000 training examples, each with 3 fixed attributes and a
    # 20-step time series of 2 features per step.
    attributes = np.random.rand(10000, 3)
    features = np.random.rand(10000, 20, 2)

    config = DGANConfig(
        max_sequence_len=20,  # all sequences must have the same length
        sample_len=5,         # must be a divisor of max_sequence_len
        batch_size=1000,
        epochs=10,
    )

    model = DGAN(config)
    model.train(attributes, features)

    # Generate 1,000 synthetic examples.
    synthetic_attributes, synthetic_features = model.generate(1000)
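
``generate`` returns a tuple of NumPy arrays; with the setup above, the
synthetic data should mirror the training shapes. A quick, illustrative
sanity check:

.. code-block:: python

    assert synthetic_attributes.shape == (1000, 3)
    assert synthetic_features.shape == (1000, 20, 2)
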
.. automodule:: gretel_synthetics.timeseries_dgan.config
    :members:

.. automodule:: gretel_synthetics.timeseries_dgan.dgan
    :special-members: __init__
    :members:
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -3,5 +3,6 @@ sphinx==3.0.3
docutils==0.17.1
mistune==0.8.4
sphinx-rtd-theme
jinja2==3.0.3
-r ../requirements.txt
-r ../utils-requirements.txt
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,4 +5,5 @@ sentencepiece==0.1.91
smart_open>=2.1.0,<3.0
tensorflow==2.4.0
tensorflow_privacy==0.5.1
torch==1.11.0
tqdm<5.0
Empty file.
124 changes: 124 additions & 0 deletions src/gretel_synthetics/timeseries_dgan/config.py
@@ -0,0 +1,124 @@
from dataclasses import dataclass
from enum import Enum


class OutputType(Enum):
    """Supported variable types.

    Determines internal representation of variables and output layers in
    generation network.
    """

    DISCRETE = 0
    CONTINUOUS = 1
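
    # Illustrative examples (assumption, not from this module): a categorical
    # field such as a connection's protocol would be DISCRETE, while a numeric
    # measurement such as bandwidth would be CONTINUOUS.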


class Normalization(Enum):
    """Normalization types for continuous variables.

    Determines if a sigmoid (ZERO_ONE) or tanh (MINUSONE_ONE) activation is
    used for the output layers in the generation network.
    """

    ZERO_ONE = 0
    MINUSONE_ONE = 1
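
    # Illustrative summary (assumes apply_feature_scaling=True in DGANConfig):
    #   ZERO_ONE     -> sigmoid output activation, data scaled to [0, 1]
    #   MINUSONE_ONE -> tanh output activation,    data scaled to [-1, 1]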


@dataclass
class DGANConfig:
    """Config object with parameters for training a DGAN model.

    Args:
        max_sequence_len: length of time series sequences, variable length
            sequences are not supported, so all training and generated data
            will have the same length sequences
        sample_len: time series steps to generate from each LSTM cell in DGAN,
            must be a divisor of max_sequence_len
        attribute_noise_dim: length of the GAN noise vectors for attribute
            generation
        feature_noise_dim: length of GAN noise vectors for feature generation
        attribute_num_layers: # of layers in the GAN discriminator network
        attribute_num_units: # of units per layer in the GAN discriminator
            network
        feature_num_layers: # of LSTM layers in the GAN generator network
        feature_num_units: # of units per layer in the GAN generator network
        use_attribute_discriminator: use separate discriminator only on
            attributes, helps DGAN match attribute distributions, Default:
            True
        normalization: default normalization for continuous variables, used
            when metadata output is not specified during DGAN initialization
        apply_feature_scaling: scale each continuous variable to [0,1] or
            [-1,1] (based on normalization param) before training and rescale
            to original range during generation, if False then training data
            must be within range and DGAN will only generate values in [0,1]
            or [-1,1], Default: True
        apply_example_scaling: compute midpoint and halfrange (equivalent to
            min/max) for each time series variable and include these as
            additional attributes that are generated, this provides better
            support for time series with highly variable ranges, e.g., in
            network data, a dial-up connection has bandwidth usage in [1kb,
            10kb], while a fiber connection is in [100mb, 1gb], Default: True
        forget_bias: initialize forget gate bias parameters to 1 in LSTM
            layers, when True initialization matches tf1 LSTMCell behavior,
            otherwise default pytorch initialization is used, Default: False
        gradient_penalty_coef: coefficient for gradient penalty in Wasserstein
            loss, Default: 10.0
        attribute_gradient_penalty_coef: coefficient for gradient penalty in
            Wasserstein loss for the attribute discriminator, Default: 10.0
        attribute_loss_coef: coefficient for attribute discriminator loss in
            comparison to the standard discriminator on attributes and
            features, higher values should encourage DGAN to learn attribute
            distributions, Default: 1.0
        generator_learning_rate: learning rate for Adam optimizer
        generator_beta1: Adam param for exponential decay of 1st moment
        discriminator_learning_rate: learning rate for Adam optimizer
        discriminator_beta1: Adam param for exponential decay of 1st moment
        attribute_discriminator_learning_rate: learning rate for Adam
            optimizer
        attribute_discriminator_beta1: Adam param for exponential decay of 1st
            moment
        batch_size: # of examples used in batches, for both training and
            generation
        epochs: # of epochs to train model
        discriminator_rounds: training steps for the discriminator(s) in each
            batch
        generator_rounds: training steps for the generator in each batch
        cuda: use GPU if available
    """

    # Model structure
    max_sequence_len: int
    sample_len: int

    attribute_noise_dim: int = 10
    feature_noise_dim: int = 10
    attribute_num_layers: int = 3
    attribute_num_units: int = 100
    feature_num_layers: int = 1
    feature_num_units: int = 100
    use_attribute_discriminator: bool = True

    # Data transformation
    normalization: Normalization = Normalization.ZERO_ONE
    apply_feature_scaling: bool = True
    apply_example_scaling: bool = True

    # Model initialization
    forget_bias: bool = False

    # Loss function
    gradient_penalty_coef: float = 10.0
    attribute_gradient_penalty_coef: float = 10.0
    attribute_loss_coef: float = 1.0

    # Training
    generator_learning_rate: float = 0.001
    generator_beta1: float = 0.5
    discriminator_learning_rate: float = 0.001
    discriminator_beta1: float = 0.5
    attribute_discriminator_learning_rate: float = 0.001
    attribute_discriminator_beta1: float = 0.5
    batch_size: int = 1024
    epochs: int = 400
    discriminator_rounds: int = 1
    generator_rounds: int = 1

    cuda: bool = True
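
# Usage sketch (illustrative only, not part of this module): sample_len must
# divide max_sequence_len, since each LSTM cell emits sample_len time steps.
#
#     from gretel_synthetics.timeseries_dgan.config import (
#         DGANConfig,
#         Normalization,
#     )
#
#     config = DGANConfig(
#         max_sequence_len=24,  # 24 = 6 LSTM cells x 4 steps per cell
#         sample_len=4,
#         normalization=Normalization.MINUSONE_ONE,  # tanh output layers
#         batch_size=256,
#         epochs=100,
#     )
#     assert config.max_sequence_len % config.sample_len == 0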
