In [1]:
import os
# Work from two directories above the notebook so that relative paths used later
# (e.g. the `modules` package import and '../data', '../output') resolve as intended.
# NOTE(review): presumably the notebook lives two levels below the repo root — confirm.
project_root = '../..'
os.chdir(project_root)

In [2]:
from platform import python_version
# Record the interpreter version in the saved output so the run's environment is documented.
interpreter_version = python_version()
print(interpreter_version)

3.7.9


In [3]:
import numpy as np
import pandas as pd

import tensorflow as tf

import matplotlib.pyplot as plt

from modules.datgan import DATWGAN

import networkx as nx
import json
import beepy

# For the Python notebook
%matplotlib inline
%reload_ext autoreload
%autoreload 2






In [4]:
# Experiment configuration.
#   dataset: sub-folder used under ../data/ (input CSV) and ../output/ (results).
#   name:    graph variant built below; one of 'FULLDAG', 'TRANSRED', 'LINEAR',
#            'NOLINKS', 'CHOICE'. Also used as the output sub-folder.
dataset, name = 'Chicago', 'FULLDAG'

In [5]:
# Load the raw dataset; index_col=False stops pandas from using the first column as the index.
data_path = f'../data/{dataset}/data.csv'
df = pd.read_csv(data_path, index_col=False)

In [6]:
# Quick look at the first rows to sanity-check the loaded columns.
df.head(n=5)

Unnamed: 0,choice,travel_dow,trip_purpose,distance,hh_vehicles,hh_size,hh_bikes,hh_descr,hh_income,gender,age,license,education_level,work_status,departure_time
0,drive,7,HOME_OTHER,23.42579,2,2,0,2,6,1,66,1.0,6,FTE,9.333333
1,drive,7,OTHER,1.71259,2,2,0,2,6,1,66,1.0,6,FTE,12.083333
2,drive,7,HOME_OTHER,21.77887,2,2,0,2,6,1,66,1.0,6,FTE,15.5
3,drive,7,SHOPPING,2.02603,2,2,0,2,6,1,66,1.0,6,FTE,17.5
4,drive,7,SHOPPING,0.87691,2,2,0,2,6,1,66,1.0,6,FTE,18.25


In [7]:
# Columns passed to DATWGAN as continuous variables.
# NOTE(review): presumably all remaining columns are handled as categorical — confirm in DATWGAN.
continuous_columns = [
    "distance",
    "age",
    "departure_time",
]

In [8]:
# Personalised graph: the DAG over the dataframe columns that is passed to DATWGAN.fit.
# `name` selects the variant:
#   FULLDAG  - hand-specified links between variables, plus some additional
#              (non-necessary) links
#   TRANSRED - the FULLDAG graph after nx.transitive_reduction
#   LINEAR   - a chain through the dataframe columns in their CSV order
#   NOLINKS  - every column as an isolated node (no edges)
#   CHOICE   - every column other than 'choice' pointing directly at 'choice'
# Raises ValueError for any other `name`.
graph = nx.DiGraph()

if name in ('FULLDAG', 'TRANSRED'):
    graph.add_edges_from([
        ("age", "license"),
        ("age", "education_level"),
        ("gender", "work_status"),
        ("education_level", "work_status"),
        ("education_level", "hh_income"),
        ("work_status", "hh_income"),
        ("hh_income", "hh_descr"),
        ("hh_income", "hh_size"),
        ("hh_size", "hh_vehicles"),
        ("hh_size", "hh_bikes"),
        ("work_status", "trip_purpose"),
        ("trip_purpose", "departure_time"),
        ("trip_purpose", "distance"),
        ("travel_dow", "choice"),
        ("distance", "choice"),
        ("departure_time", "choice"),
        ("hh_vehicles", "choice"),
        ("hh_bikes", "choice"),
        ("license", "choice"),
        # Additional (non-necessary) links
        ("education_level", "hh_size"),
        ("work_status", "hh_descr"),
        ("work_status", "hh_size"),
        ("hh_income", "hh_bikes"),
        ("hh_income", "hh_vehicles"),
        ("trip_purpose", "choice")
    ])

    # Compare strings by value with `==`: `is` tests object identity and only
    # works by accident of string interning (Python 3.8+ warns about it).
    if name == 'TRANSRED':
        # Drop every edge implied by a longer path; returns a new graph.
        graph = nx.transitive_reduction(graph)
elif name == 'LINEAR':
    # Chain consecutive columns: col0 -> col1 -> ... -> colN.
    graph.add_edges_from(zip(df.columns[:-1], df.columns[1:]))
elif name == 'NOLINKS':
    # Nodes only, so every variable is independent in the graph.
    graph.add_nodes_from(df.columns)
elif name == 'CHOICE':
    # Star graph: every other column is a direct parent of 'choice'.
    graph.add_edges_from((c, 'choice') for c in df.columns if c != 'choice')
else:
    # Fail loudly instead of silently training on an empty graph.
    raise ValueError(
        "Unknown graph name {!r}; expected one of "
        "'FULLDAG', 'TRANSRED', 'LINEAR', 'NOLINKS', 'CHOICE'".format(name)
    )

In [9]:
# Results for each (dataset, graph variant) pair go to their own folder, with a trailing slash.
output_folder = f'../output/{dataset}/{name}/'

In [10]:
# Build the model with one keyword argument per line so the settings are easy to scan.
# NOTE(review): meanings below are inferred from argument names / logs — confirm in DATWGAN.
datgan = DATWGAN(
    continuous_columns,
    max_epoch=1000,        # presumably the number of training epochs
    batch_size=500,        # matches the leading dimension in the training log shapes
    output=output_folder,  # folder where the preprocessed data is looked up / results written
    gpu=0,                 # presumably the GPU device index
)

In [11]:
# Train DATWGAN on the loaded dataframe using the graph built above.
# NOTE(review): the recorded output of this cell ends with a KeyboardInterrupt during
# epoch 1, so this saved run did not finish training — re-run before relying on results.
# NOTE(review): `test` holds whatever `fit` returns (possibly None) — confirm against
# DATWGAN.fit and consider a more descriptive name.
test = datgan.fit(df, graph)

[32m[1104 19:23:06 @DATSGAN.py:155][0m Found preprocessed data
[32m[1104 19:23:07 @DATSGAN.py:163][0m Preprocessed data have been loaded!



[32m[1104 19:23:07 @input_source.py:222][0m Setting up the queue 'QueueInput/input_queue' for CPU prefetching ...








Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
[32m[1104 19:23:07 @DATSGANModel.py:211][0m [91mCreating cell for age (in-edges: 0; ancestors: 0)
Instructions for updating:
Please use `layer.add_weight` method instead.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
[32m[1104 19:23:07 @registry.py:126][0m gen/LSTM/age/FC input: [500, 100]

Instructions for updating:
Please use `layer.__call__` method instead.
[32m[1104 19:23:07 @registry.py:134][0m gen/LSTM/age/FC output: [500, 100]
[32m[1104 19:23:07 @registry.py:126][0m gen/LSTM/age/FC2_val input: [500, 100]
[32

[32m[1104 19:23:07 @registry.py:126][0m gen/LSTM/work_status/FC input: [500, 100]
[32m[1104 19:23:07 @registry.py:134][0m gen/LSTM/work_status/FC output: [500, 100]
[32m[1104 19:23:07 @registry.py:126][0m gen/LSTM/work_status/FC2 input: [500, 100]
[32m[1104 19:23:07 @registry.py:134][0m gen/LSTM/work_status/FC2 output: [500, 8]
[32m[1104 19:23:07 @registry.py:126][0m gen/LSTM/work_status/FC3 input: [500, 8]
[32m[1104 19:23:07 @registry.py:134][0m gen/LSTM/work_status/FC3 output: [500, 100]
[32m[1104 19:23:07 @DATSGANModel.py:211][0m [91mCreating cell for hh_income (in-edges: 2; ancestors: 4)
[32m[1104 19:23:07 @registry.py:126][0m gen/LSTM/concat-hh_income/FC_inputs input: [500, 200]
[32m[1104 19:23:07 @registry.py:134][0m gen/LSTM/concat-hh_income/FC_inputs output: [500, 100]
[32m[1104 19:23:07 @registry.py:126][0m gen/LSTM/concat-hh_income/FC_attentions input: [500, 200]
[32m[1104 19:23:07 @registry.py:134][0m gen/LSTM/concat-hh_income/FC_attentions output: [50

[32m[1104 19:23:08 @registry.py:134][0m gen/LSTM/hh_vehicles/FC3 output: [500, 100]
[32m[1104 19:23:08 @DATSGANModel.py:211][0m [91mCreating cell for hh_bikes (in-edges: 2; ancestors: 6)
[32m[1104 19:23:08 @registry.py:126][0m gen/LSTM/concat-hh_bikes/FC_inputs input: [500, 200]
[32m[1104 19:23:08 @registry.py:134][0m gen/LSTM/concat-hh_bikes/FC_inputs output: [500, 100]
[32m[1104 19:23:08 @registry.py:126][0m gen/LSTM/concat-hh_bikes/FC_attentions input: [500, 200]
[32m[1104 19:23:08 @registry.py:134][0m gen/LSTM/concat-hh_bikes/FC_attentions output: [500, 100]
[32m[1104 19:23:08 @registry.py:126][0m gen/LSTM/concat-hh_bikes/FC_lstm_state_0 input: [500, 200]
[32m[1104 19:23:08 @registry.py:134][0m gen/LSTM/concat-hh_bikes/FC_lstm_state_0 output: [500, 100]
[32m[1104 19:23:08 @registry.py:126][0m gen/LSTM/concat-hh_bikes/FC_lstm_state_1 input: [500, 200]
[32m[1104 19:23:08 @registry.py:134][0m gen/LSTM/concat-hh_bikes/FC_lstm_state_1 output: [500, 100]
[32m[1104 1

[32m[1104 19:23:11 @base.py:209][0m Setup callbacks graph ...






[32m[1104 19:23:12 @summary.py:46][0m [MovingAverageSummary] 3 operations in collection 'MOVING_SUMMARY_OPS' will be run with session hooks.
[32m[1104 19:23:12 @summary.py:93][0m Summarizing collection 'summaries' of size 4.

[32m[1104 19:23:12 @graph.py:98][0m Applying collection UPDATE_OPS of 4 ops.

[32m[1104 19:23:13 @base.py:230][0m Creating the session ...


Instructions for updating:
Please use tensorflow.python.ops.op_selector.get_backward_walk_ops.



[32m[1104 19:23:14 @base.py:236][0m Initializing the session ...
[32m[1104 19:23:14 @base.py:243][0m Graph Finalized.


[32m[1104 19:23:14 @concurrency.py:38][0m Starting EnqueueThread QueueInput/input_queue ...

[32m[1104 19:23:15 @base.py:275][0m Start Epoch 1 ...


  2%|#4                                                                                   |3/175[00:02<02:09, 1.33it/s]

[32m[1104 19:23:17 @base.py:293][0m Detected Ctrl-C and exiting main loop.


  2%|#9                                                                                   |4/175[00:02<01:38, 1.73it/s]


[32m[1104 19:23:17 @input_source.py:178][0m EnqueueThread QueueInput/input_queue Exited.


KeyboardInterrupt: 

In [None]:
# Save the model under the name 'trained'.
# NOTE(review): force=True presumably overwrites an existing save — confirm in DATWGAN.save.
# NOTE(review): if run right after the interrupted fit recorded above, this would persist
# a partially trained model.
datgan.save('trained', force=True)

In [None]:
# Audible notification that the notebook has reached its last cell.
# NOTE(review): 6 is presumably one of beepy's built-in sound ids — confirm in beepy's docs.
beepy.beep(6)