In [1]:
import os

# All paths below (and the `datgan` import) are relative to the repository root,
# presumably one level above this notebook. Only move up when `example/data` is not
# already reachable, so re-running this cell does not keep climbing directories.
if not os.path.isdir(os.path.join('example', 'data')):
    os.chdir('..')

In [2]:
from datgan.datgan import DATGAN

import numpy as np
import pandas as pd
import networkx as nx

# For the Python notebook
%matplotlib inline
%reload_ext autoreload
%autoreload 2






# Load the original data

In [3]:
# Original CMAP dataset used to train the DATGAN
cmap_csv = 'example/data/CMAP.csv'
df = pd.read_csv(cmap_csv, index_col=False)

In [4]:
# Preview the first rows of the original data
df.head(n=5)

Unnamed: 0,choice,travel_dow,trip_purpose,distance,hh_vehicles,hh_size,hh_bikes,hh_descr,hh_income,gender,age,license,education_level,work_status,departure_time
0,drive,7,HOME_OTHER,3.93477,2,3,3,detached,6,0,30,1,4,PTE,20.166667
1,drive,2,SHOPPING,0.31557,3,3,3,detached,7,0,54,1,5,FTE,17.5
2,drive,2,SHOPPING,0.28349,1,1,0,detached,3,0,80,1,3,PTE,9.333333
3,drive,2,OTHER,0.69417,2,2,0,detached,5,1,42,1,5,FTE,13.783333
4,passenger,1,SHOPPING,4.30666,2,2,1,detached,4,0,32,0,3,Unemployed,11.566667


# DAG

We need to specify which columns are continuous and define the DAG (directed acyclic graph) of dependencies between the variables, using the `networkx` library.

In [5]:
# Columns that DATGAN should model as continuous variables
continuous_columns = [
    'distance',
    'age',
    'departure_time',
]

In [6]:
# DAG of assumed dependencies between the variables: an edge (a, b) makes `b`
# depend on `a` in the generator.
# NOTE: the edge order is kept as-is on purpose. It fixes networkx's node insertion
# order, which may influence the order in which DATGAN builds its cells.
graph = nx.DiGraph()
graph.add_edges_from([
    ("age", "license"),
    ("age", "education_level"),
    ("gender", "work_status"),
    ("education_level", "work_status"),
    ("education_level", "hh_income"),
    ("work_status", "hh_income"),
    ("hh_income", "hh_descr"),
    ("hh_income", "hh_size"),
    ("hh_size", "hh_vehicles"),
    ("hh_size", "hh_bikes"),
    ("work_status", "trip_purpose"),
    ("trip_purpose", "departure_time"),
    ("trip_purpose", "distance"),
    ("travel_dow", "choice"),
    ("distance", "choice"),
    ("departure_time", "choice"),
    ("hh_vehicles", "choice"),
    ("hh_bikes", "choice"),
    ("license", "choice"),
    ("education_level", "hh_size"),
    ("work_status", "hh_descr"),
    ("work_status", "hh_size"),
    ("hh_income", "hh_bikes"),
    ("hh_income", "hh_vehicles"),
    ("trip_purpose", "choice")
])

# Sanity checks before the long training run: the graph must be a DAG, and it must
# only reference columns that exist in the data (including every continuous column).
assert nx.is_directed_acyclic_graph(graph), "the variable graph must be acyclic"
unknown_nodes = set(graph.nodes) - set(df.columns)
assert not unknown_nodes, f"graph nodes not found in the data: {sorted(unknown_nodes)}"
assert set(continuous_columns) <= set(graph.nodes), "every continuous column must be a node of the graph"

# Training the DATGAN

In [7]:
# Directory where DATGAN writes its outputs (model checkpoints are stored under it)
output_folder = "./example/output/"

In [8]:
# Build the model: outputs go to `output_folder`, training uses max_epoch=1000
datgan = DATGAN(
    max_epoch=1000,
    output=output_folder,
)

The data can be preprocessed once and saved to disk. Since preprocessing takes some time, reusing the saved result makes it faster to test multiple models.

In [9]:
# Encode the data and store it at this path (or load it if it was already saved there)
datgan.preprocess(
    df,
    continuous_columns,
    preprocessed_data_path='./example/encoded_data',
)

[32m[0112 13:11:15 @datgan.py:202][0m Preprocessed data have been loaded!


If the data has already been preprocessed, you need to provide the path where it was saved. The model will then load the preprocessed data from that path and work with it.

In [10]:
# Train the DATGAN on the data, reusing the preprocessed data saved at this path
datgan.fit(
    df,
    graph,
    continuous_columns,
    preprocessed_data_path='./example/encoded_data',
)

[32m[0112 13:11:15 @datgan.py:202][0m Preprocessed data have been loaded!



[32m[0112 13:11:15 @input_source.py:222][0m Setting up the queue 'QueueInput/input_queue' for CPU prefetching ...








Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
[32m[0112 13:11:15 @DATGANSynthesizer.py:138][0m [91mCreating cell for age (in-edges: 0; ancestors: 0)
Instructions for updating:
Please use `layer.add_weight` method instead.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
[32m[0112 13:11:15 @registry.py:126][0m gen/LSTM/age/FC input: [500, 100]

Instructions for updating:
Please use `layer.__call__` method instead.
[32m[0112 13:11:15 @registry.py:134][0m gen/LSTM/age/FC output: [500, 50]
[32m[0112 13:11:15 @registry.py:126][0m gen/LSTM/age/FC_val input: [500, 50]
[32m[0112 13:11:15 @registry.py:134][0m gen/LSTM/age/FC_val output

[32m[0112 13:11:16 @registry.py:134][0m gen/LSTM/age-gender/FC_noise output: [None, 200]
[32m[0112 13:11:16 @registry.py:126][0m gen/LSTM/work_status/FC input: [500, 100]
[32m[0112 13:11:16 @registry.py:134][0m gen/LSTM/work_status/FC output: [500, 50]
[32m[0112 13:11:16 @registry.py:126][0m gen/LSTM/work_status/FC_prob input: [500, 50]
[32m[0112 13:11:16 @registry.py:134][0m gen/LSTM/work_status/FC_prob output: [500, 8]
[32m[0112 13:11:16 @registry.py:126][0m gen/LSTM/work_status/FC_input input: [500, 8]
[32m[0112 13:11:16 @registry.py:134][0m gen/LSTM/work_status/FC_input output: [500, 100]
[32m[0112 13:11:16 @DATGANSynthesizer.py:138][0m [91mCreating cell for hh_income (in-edges: 2; ancestors: 4)
[32m[0112 13:11:16 @registry.py:126][0m gen/LSTM/concat-hh_income/FC_inputs input: [500, 200]
[32m[0112 13:11:16 @registry.py:134][0m gen/LSTM/concat-hh_income/FC_inputs output: [500, 100]
[32m[0112 13:11:16 @registry.py:126][0m gen/LSTM/concat-hh_income/FC_states inp

[32m[0112 13:11:16 @registry.py:126][0m gen/LSTM/concat-hh_bikes/FC_h_outputs input: [500, 200]
[32m[0112 13:11:16 @registry.py:134][0m gen/LSTM/concat-hh_bikes/FC_h_outputs output: [500, 100]
[32m[0112 13:11:16 @registry.py:126][0m gen/LSTM/hh_bikes/FC input: [500, 100]
[32m[0112 13:11:16 @registry.py:134][0m gen/LSTM/hh_bikes/FC output: [500, 50]
[32m[0112 13:11:16 @registry.py:126][0m gen/LSTM/hh_bikes/FC_prob input: [500, 50]
[32m[0112 13:11:16 @registry.py:134][0m gen/LSTM/hh_bikes/FC_prob output: [500, 8]
[32m[0112 13:11:16 @registry.py:126][0m gen/LSTM/hh_bikes/FC_input input: [500, 8]
[32m[0112 13:11:16 @registry.py:134][0m gen/LSTM/hh_bikes/FC_input output: [500, 100]
[32m[0112 13:11:16 @DATGANSynthesizer.py:138][0m [91mCreating cell for choice (in-edges: 7; ancestors: 13)
[32m[0112 13:11:16 @registry.py:126][0m gen/LSTM/concat-choice/FC_inputs input: [500, 700]
[32m[0112 13:11:16 @registry.py:134][0m gen/LSTM/concat-choice/FC_inputs output: [500, 100]


[32m[0112 13:11:20 @base.py:209][0m Setup callbacks graph ...







[32m[0112 13:11:20 @utils.py:26][0m Clip discrim/DISCR_FC_0/FC/W

[32m[0112 13:11:20 @utils.py:26][0m Clip discrim/DISCR_FC_0/FC/b
[32m[0112 13:11:20 @utils.py:26][0m Clip discrim/DISCR_FC_0/FC_DIVERSITY/W
[32m[0112 13:11:20 @utils.py:26][0m Clip discrim/DISCR_FC_0/FC_DIVERSITY/b
[32m[0112 13:11:20 @utils.py:26][0m Clip discrim/DISCR_FC_0/BN/beta
[32m[0112 13:11:20 @utils.py:26][0m Clip discrim/DISCR_FC_TOP/W
[32m[0112 13:11:20 @utils.py:26][0m Clip discrim/DISCR_FC_TOP/b
[32m[0112 13:11:20 @summary.py:46][0m [MovingAverageSummary] 3 operations in collection 'MOVING_SUMMARY_OPS' will be run with session hooks.
[32m[0112 13:11:20 @summary.py:93][0m Summarizing collection 'summaries' of size 4.

[32m[0112 13:11:20 @graph.py:98][0m Applying collection UPDATE_OPS of 4 ops.



[32m[0112 13:11:21 @base.py:230][0m Creating the session ...


Instructions for updating:
Please use tensorflow.python.ops.o

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
[32m[0112 13:11:26 @sessinit.py:114][0m Restoring checkpoint from ./example/output/model\model-25499 ...
INFO:tensorflow:Restoring parameters from ./example/output/model\model-25499


You can save the trained model to reuse it later.

In [11]:
# Save the trained model under the name 'trained'
# (force=True presumably overwrites an existing save; verify against DATGAN.save)
datgan.save(
    'trained',
    force=True,
)

[32m[0112 13:11:42 @datgan.py:519][0m Model saved successfully.


Sample the synthetic data (here, as many rows as in the original dataset).

In [13]:
# Generate as many synthetic rows as there are rows in the original data
n_samples = len(df)
samples = datgan.sample(n_samples)

|                                                                                          |17/?[00:00<00:00,69.52it/s]


If a continuous column has a discrete distribution in the original data (here `age`, which only takes whole values), you need to round the sampled values. Future work will fix this.

In [14]:
# `age` is modelled as a continuous column, but it only takes whole values in the
# original data. Round the sampled values and cast them back to the original dtype,
# so the synthetic file matches the original format (e.g. `30` rather than `30.0`).
samples['age'] = samples['age'].round().astype(df['age'].dtype)

In [15]:
# Write the synthetic dataset next to the original one, without the pandas index
synthetic_csv = 'example/data/CMAP_synthetic.csv'
samples.to_csv(synthetic_csv, index=False)