In [1]:
%%capture
!pip install numpy pandas matplotlib torch
# TODO: switch to use pypi package gretel-synthetics once released
!pip install git+https://github.com/gretelai/gretel-synthetics.git


In [2]:
import numpy as np
import pandas as pd

from gretel_synthetics.timeseries_dgan.dgan import DGAN
from gretel_synthetics.timeseries_dgan.config import DGANConfig, OutputType, Normalization

# Training from a DataFrame

In [3]:
# Create some random training data.
# A locally seeded generator makes the frame identical on every run of this cell.
rng = np.random.default_rng(42)
df = pd.DataFrame(
    # 1000 rows of 30 uniform values in [0, 1)
    rng.random(size=(1000, 30)),
    # One column per timestamp: 30 consecutive days starting 2022-01-01
    columns=pd.date_range("2022-01-01", periods=30),
)
# Include an attribute column with integer values drawn from {0, 1, 2}
df["attribute"] = rng.integers(0, 3, size=1000)

df

Unnamed: 0,2022-01-01 00:00:00,2022-01-02 00:00:00,2022-01-03 00:00:00,2022-01-04 00:00:00,2022-01-05 00:00:00,2022-01-06 00:00:00,2022-01-07 00:00:00,2022-01-08 00:00:00,2022-01-09 00:00:00,2022-01-10 00:00:00,...,2022-01-22 00:00:00,2022-01-23 00:00:00,2022-01-24 00:00:00,2022-01-25 00:00:00,2022-01-26 00:00:00,2022-01-27 00:00:00,2022-01-28 00:00:00,2022-01-29 00:00:00,2022-01-30 00:00:00,attribute
0,0.560833,0.021828,0.625366,0.820607,0.578291,0.382111,0.208016,0.589131,0.001629,0.260545,...,0.935479,0.023924,0.769155,0.605583,0.516744,0.166947,0.460190,0.898367,0.750524,2
1,0.047354,0.506715,0.055981,0.994086,0.932349,0.350104,0.878353,0.638448,0.123029,0.783326,...,0.074828,0.418722,0.245674,0.288377,0.873725,0.194509,0.839084,0.711169,0.377632,0
2,0.653394,0.894514,0.666302,0.109867,0.919684,0.587311,0.833353,0.205244,0.649074,0.273094,...,0.635373,0.499813,0.500418,0.654058,0.776763,0.226038,0.047210,0.439800,0.749124,2
3,0.218167,0.455448,0.784614,0.951239,0.631821,0.657217,0.884685,0.815268,0.231759,0.591914,...,0.891647,0.753050,0.922394,0.597673,0.210616,0.499074,0.458702,0.759720,0.312400,1
4,0.600659,0.535101,0.703956,0.142726,0.313174,0.586587,0.684656,0.425363,0.134775,0.503620,...,0.960892,0.779443,0.580709,0.971533,0.275259,0.637244,0.663237,0.764521,0.128809,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,0.934307,0.893747,0.564022,0.163139,0.120658,0.629173,0.077333,0.468956,0.153451,0.327377,...,0.849821,0.533830,0.313332,0.990926,0.005402,0.467682,0.241946,0.184720,0.779400,2
996,0.428448,0.745350,0.925675,0.689218,0.074771,0.197755,0.243318,0.373151,0.577873,0.326399,...,0.484315,0.119533,0.372267,0.245859,0.567435,0.063363,0.486922,0.229001,0.028821,2
997,0.109175,0.217878,0.263733,0.900757,0.534955,0.109753,0.597126,0.286169,0.300131,0.700751,...,0.979818,0.995882,0.934597,0.433416,0.895675,0.040983,0.283757,0.895343,0.374732,0
998,0.030816,0.301465,0.704908,0.277258,0.889353,0.132958,0.714270,0.418214,0.124248,0.261260,...,0.780267,0.354577,0.273570,0.786567,0.580439,0.429945,0.141935,0.550360,0.832988,0


In [4]:
# Build the DGAN configuration, then fit the model on the wide DataFrame
dgan_config = DGANConfig(
    epochs=10,  # For real data sets, 100-1000 epochs is typical
    batch_size=1000,
    sample_len=3,
    max_sequence_len=30,
)
model = DGAN(dgan_config)

# The "attribute" column is passed as the only attribute, typed as discrete
model.train_dataframe(
    df,
    attribute_types=[OutputType.DISCRETE],
    df_attribute_columns=["attribute"],
)

# Generate synthetic data (the argument to generate_dataframe is 100)
synthetic_df = model.generate_dataframe(100)

synthetic_df

2022-05-20 17:43:51,017 : MainThread : INFO : epoch: 0
2022-05-20 17:43:51,345 : MainThread : INFO : epoch: 1
2022-05-20 17:43:51,677 : MainThread : INFO : epoch: 2
2022-05-20 17:43:51,998 : MainThread : INFO : epoch: 3
2022-05-20 17:43:52,327 : MainThread : INFO : epoch: 4
2022-05-20 17:43:52,657 : MainThread : INFO : epoch: 5
2022-05-20 17:43:52,979 : MainThread : INFO : epoch: 6
2022-05-20 17:43:53,302 : MainThread : INFO : epoch: 7
2022-05-20 17:43:53,633 : MainThread : INFO : epoch: 8
2022-05-20 17:43:53,959 : MainThread : INFO : epoch: 9


Unnamed: 0,attribute,2022-01-01 00:00:00,2022-01-02 00:00:00,2022-01-03 00:00:00,2022-01-04 00:00:00,2022-01-05 00:00:00,2022-01-06 00:00:00,2022-01-07 00:00:00,2022-01-08 00:00:00,2022-01-09 00:00:00,...,2022-01-21 00:00:00,2022-01-22 00:00:00,2022-01-23 00:00:00,2022-01-24 00:00:00,2022-01-25 00:00:00,2022-01-26 00:00:00,2022-01-27 00:00:00,2022-01-28 00:00:00,2022-01-29 00:00:00,2022-01-30 00:00:00
0,1.0,0.422214,0.432031,0.480368,0.395353,0.446631,0.483186,0.384243,0.466559,0.477707,...,0.489541,0.344399,0.475066,0.506682,0.342995,0.478938,0.493156,0.326783,0.477762,0.497629
1,1.0,0.533715,0.543091,0.579105,0.529386,0.556551,0.580874,0.524565,0.566012,0.584089,...,0.585920,0.483700,0.581420,0.594063,0.482967,0.578164,0.595518,0.477830,0.586937,0.600302
2,1.0,0.485039,0.482299,0.509467,0.470266,0.487701,0.519129,0.470279,0.506402,0.519168,...,0.531473,0.440687,0.504846,0.534355,0.432567,0.513546,0.525229,0.431557,0.509433,0.525629
3,1.0,0.479240,0.497511,0.502894,0.461390,0.497665,0.518905,0.446949,0.508497,0.528179,...,0.533625,0.403045,0.521507,0.556600,0.398584,0.525295,0.545459,0.409793,0.527263,0.536226
4,0.0,0.293686,0.279307,0.322273,0.281750,0.301896,0.333495,0.260899,0.321256,0.338657,...,0.357883,0.194592,0.362058,0.345538,0.203562,0.364923,0.333071,0.199174,0.352475,0.359166
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,0.0,0.442134,0.444689,0.490872,0.422184,0.467230,0.491960,0.406500,0.480944,0.491579,...,0.529950,0.373593,0.511844,0.514150,0.368716,0.492807,0.515326,0.367020,0.483560,0.519339
96,1.0,0.467014,0.471248,0.510128,0.450343,0.485701,0.504931,0.440208,0.492922,0.510640,...,0.527472,0.388823,0.492023,0.536090,0.385792,0.512149,0.540428,0.378828,0.511057,0.545419
97,1.0,0.405683,0.399558,0.424499,0.382327,0.408988,0.447265,0.374020,0.416329,0.444756,...,0.470237,0.326809,0.449380,0.477480,0.325117,0.450786,0.465375,0.311560,0.460538,0.473520
98,1.0,0.555681,0.557614,0.572440,0.557323,0.567699,0.571970,0.550145,0.566753,0.572892,...,0.577986,0.541558,0.577835,0.579436,0.538783,0.574859,0.577466,0.532367,0.572600,0.582335


# Training from NumPy arrays

In [5]:
# Create some random training data.
# A locally seeded generator makes both arrays identical on every run of this cell.
rng = np.random.default_rng(42)
# Integer array of shape (1000, 3) with values drawn from {0, 1, 2}
attributes = rng.integers(0, 3, size=(1000, 3))
# Float array of shape (1000, 20, 2) with uniform values in [0, 1)
features = rng.random(size=(1000, 20, 2))

In [6]:
# Display the training attribute array (bare last expression renders its repr)
attributes

array([[2, 0, 2],
       [2, 0, 1],
       [0, 0, 0],
       ...,
       [2, 1, 1],
       [2, 1, 0],
       [2, 1, 0]])

In [7]:
# Display the training feature array (bare last expression renders its repr)
features

array([[[0.84864839, 0.90278846],
        [0.64863848, 0.73860317],
        [0.6389165 , 0.48596729],
        ...,
        [0.92461491, 0.54725386],
        [0.23179008, 0.91517007],
        [0.16173921, 0.43569247]],

       [[0.6273537 , 0.6702016 ],
        [0.11885196, 0.24677543],
        [0.67649867, 0.19999652],
        ...,
        [0.32574603, 0.88844569],
        [0.6378902 , 0.57453578],
        [0.56277615, 0.29637052]],

       [[0.1869226 , 0.02761921],
        [0.26923597, 0.5749236 ],
        [0.63814811, 0.61559268],
        ...,
        [0.74082772, 0.42674   ],
        [0.84911351, 0.77667221],
        [0.25639317, 0.10132389]],

       ...,

       [[0.61729527, 0.83193035],
        [0.21002377, 0.98979698],
        [0.5390336 , 0.53623807],
        ...,
        [0.65082   , 0.1162104 ],
        [0.68430003, 0.71106839],
        [0.38184098, 0.67575487]],

       [[0.17717818, 0.47086302],
        [0.07198823, 0.05735601],
        [0.24222383, 0.91056876],
        .

In [8]:
# Build the DGAN configuration, then fit the model directly on the numpy arrays
dgan_config = DGANConfig(
    epochs=10,  # For real data sets, 100-1000 epochs is typical
    batch_size=1000,
    sample_len=4,
    max_sequence_len=20,
)
model = DGAN(dgan_config)

# Three discrete attribute types and two continuous feature types
attribute_types = [OutputType.DISCRETE for _ in range(3)]
feature_types = [OutputType.CONTINUOUS for _ in range(2)]
model.train_numpy(
    attributes,
    features,
    attribute_types=attribute_types,
    feature_types=feature_types,
)

# Generate synthetic data (the argument to generate_numpy is 1000)
synthetic_attributes, synthetic_features = model.generate_numpy(1000)

2022-05-20 17:43:54,397 : MainThread : INFO : epoch: 0
2022-05-20 17:43:54,664 : MainThread : INFO : epoch: 1
2022-05-20 17:43:54,919 : MainThread : INFO : epoch: 2
2022-05-20 17:43:55,179 : MainThread : INFO : epoch: 3
2022-05-20 17:43:55,434 : MainThread : INFO : epoch: 4
2022-05-20 17:43:55,700 : MainThread : INFO : epoch: 5
2022-05-20 17:43:55,950 : MainThread : INFO : epoch: 6
2022-05-20 17:43:56,211 : MainThread : INFO : epoch: 7
2022-05-20 17:43:56,460 : MainThread : INFO : epoch: 8
2022-05-20 17:43:56,730 : MainThread : INFO : epoch: 9


In [9]:
# Display the attribute array returned by model.generate_numpy
synthetic_attributes

array([[2, 1, 2],
       [2, 1, 2],
       [1, 0, 1],
       ...,
       [0, 0, 2],
       [2, 1, 2],
       [0, 1, 2]])

In [10]:
# Display the feature array returned by model.generate_numpy
synthetic_features

array([[[0.52187884, 0.577059  ],
        [0.48102787, 0.5524444 ],
        [0.4939233 , 0.512281  ],
        ...,
        [0.43824157, 0.5303508 ],
        [0.48242828, 0.45562354],
        [0.40398213, 0.50301903]],

       [[0.6135011 , 0.5808287 ],
        [0.5716421 , 0.5632861 ],
        [0.5899121 , 0.51765805],
        ...,
        [0.5465827 , 0.5548906 ],
        [0.57716244, 0.47822833],
        [0.49599853, 0.49029127]],

       [[0.6397109 , 0.5311664 ],
        [0.585794  , 0.50538105],
        [0.6048051 , 0.45897728],
        ...,
        [0.55994797, 0.48053992],
        [0.59392625, 0.43046257],
        [0.511215  , 0.45967528]],

       ...,

       [[0.55047446, 0.5590288 ],
        [0.46911034, 0.52786964],
        [0.5050881 , 0.47043025],
        ...,
        [0.44122034, 0.5046226 ],
        [0.49876353, 0.4193969 ],
        [0.39266738, 0.47686806]],

       [[0.6190106 , 0.55931777],
        [0.5734461 , 0.54048246],
        [0.5921291 , 0.499037  ],
        .