In [1]:
import os

# Work from the directory two levels above the notebook's folder; all data and
# output paths below ('../data/...', '../output/...') are relative to it.
# os.chdir('../..') is relative to the *current* directory, so re-running this
# cell would climb two more levels each time. Remember the starting directory
# once and always resolve the target from it, which makes the cell idempotent.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, '..', '..'))

In [2]:
from platform import python_version

# Record which Python interpreter produced the outputs of this notebook.
print("{}".format(python_version()))

3.7.9


In [3]:
import chevron
import sys
import os
import copy
import logging
import tensorflow as tf
import torch
import matplotlib.pyplot as plt

from distutils.dir_util import copy_tree

import pandas as pd
import networkx as nx

import time

# For the Python notebook
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [4]:
# Snapshot of the module names loaded so far. delete_loaded_modules() later
# unloads everything imported after this point, so that each training run in
# the main loop re-imports the model code from scratch.
# list() already builds an independent copy (module names are immutable
# strings), so a deepcopy is unnecessary.
init_modules = list(sys.modules.keys())

# Only show TensorFlow messages at ERROR level or above (TF1 `tf.logging` API).
tf.logging.set_verbosity(tf.logging.ERROR)

In [5]:
class redirect_output(object):
    """Context manager that redirects ``sys.stdout`` and ``sys.stderr`` to a file.

    The file is opened (and truncated) when the ``with`` block is entered and
    closed when it exits. The streams that were active on entry are restored
    even if the block raises, and the exception is then propagated.

    Args:
        out: Path of the log file that receives both stdout and stderr.

    Example:
        with redirect_output('training.log') as log:
            print('this goes to training.log')
    """

    def __init__(self, out=''):
        self.out = out
        self.log = None
        self.old_stdout = None
        self.old_stderr = None

    def __enter__(self):
        # Open lazily: merely constructing the object must not create or
        # truncate the file, nor leak a handle if the block is never entered.
        self.log = open(self.out, 'w')
        # Capture the streams at entry time (not construction time) so the ones
        # actually active around the block are the ones restored.
        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr
        sys.stdout = self.log
        sys.stderr = self.log
        return self.log

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the streams before closing, so nothing is written to a closed file.
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr
        self.log.close()
        # Returning None (falsy) lets any exception raised in the block propagate.

In [6]:
def delete_loaded_modules(init_modules):
    """Remove from ``sys.modules`` every module whose name is not in ``init_modules``.

    Subsequent ``import`` statements for the removed modules execute their code
    again instead of reusing the cached module objects.

    Args:
        init_modules: Iterable of module names to keep (e.g. a snapshot of
            ``sys.modules`` keys taken earlier).
    """
    # A set makes each membership test O(1) instead of scanning the whole list.
    keep = set(init_modules)
    # Snapshot the names in one C-level call before deleting anything. A
    # Python-level loop over the live view could be interrupted by another
    # thread importing a module, which would raise "dictionary changed size".
    to_del = [name for name in list(sys.modules) if name not in keep]
    for name in to_del:
        # pop() tolerates an entry that has already been removed meanwhile.
        sys.modules.pop(name, None)

In [7]:
def time_to_str(elapsed):
    """Render a duration given in seconds as human-readable text.

    Seconds are always present, truncated to an integer and zero-padded to two
    digits. Minutes appear whenever the duration reaches one minute, and hours
    only when it reaches one hour. For example, 65 gives
    '1 minutes and 05 seconds' and 3725 gives '1 hours 2 minutes and 05 seconds'.
    """
    whole_hours, leftover = divmod(elapsed, 3600)
    whole_minutes, leftover_seconds = divmod(leftover, 60)

    seconds_part = "{:0>2} seconds".format(int(leftover_seconds))

    pieces = []
    if whole_hours > 0:
        pieces.append("{:0>1} hours ".format(int(whole_hours)))
    if whole_hours > 0 or whole_minutes > 0:
        pieces.append("{:0>1} minutes and ".format(int(whole_minutes)))
    pieces.append(seconds_part)

    return "".join(pieces)

In [8]:
def check_already_trained(dataset, name, folder_suffix='_DAG'):
    """Return True if a saved model named ``name`` already exists for ``dataset``.

    Looks in ``../output/<dataset><folder_suffix>/<name>/`` for either
    ``trained.tar.gz`` or ``trained.pickle``. The default suffix ``'_DAG'``
    matches the folder ``train_DATGAN`` saves into in this notebook. Without it,
    models trained here are never detected and every re-run retrains them.

    Args:
        dataset: Dataset name, e.g. 'Chicago' or 'LPMC'.
        name: Model name, e.g. 'FULL_01'.
        folder_suffix: Appended to ``dataset`` to form the output folder; pass
            '' to check the plain ``../output/<dataset>/`` folder.
    """
    model_dir = os.path.join('..', 'output', '{}{}'.format(dataset, folder_suffix), name)
    return any(os.path.isfile(os.path.join(model_dir, fname))
               for fname in ('trained.tar.gz', 'trained.pickle'))

In [9]:
# Hand-specified dependency links of the Chicago dataset (FULL / TRANSRED DAGs).
_CHICAGO_EDGES = [
    ("age", "license"),
    ("age", "education_level"),
    ("gender", "work_status"),
    ("education_level", "work_status"),
    ("education_level", "hh_income"),
    ("work_status", "hh_income"),
    ("hh_income", "hh_descr"),
    ("hh_income", "hh_size"),
    ("hh_size", "hh_vehicles"),
    ("hh_size", "hh_bikes"),
    ("work_status", "trip_purpose"),
    ("trip_purpose", "departure_time"),
    ("trip_purpose", "distance"),
    ("travel_dow", "choice"),
    ("distance", "choice"),
    ("departure_time", "choice"),
    ("hh_vehicles", "choice"),
    ("hh_bikes", "choice"),
    ("license", "choice"),
    # Non-necessary links
    ("education_level", "hh_size"),
    ("work_status", "hh_descr"),
    ("work_status", "hh_size"),
    ("hh_income", "hh_bikes"),
    ("hh_income", "hh_vehicles"),
    ("trip_purpose", "choice"),
]

# Hand-specified dependency links of the LPMC dataset (FULL / TRANSRED DAGs).
_LPMC_EDGES = [
    ("travel_year", "travel_month"),
    ("travel_date", "day_of_week"),
    ("travel_month", "travel_date"),
    ("travel_month", "driving_traffic_percent"),
    ("travel_month", "day_of_week"),
    ("travel_month", "travel_mode"),
    ("day_of_week", "driving_traffic_percent"),
    ("day_of_week", "cost_driving_con_charge"),
    ("day_of_week", "purpose"),
    ("day_of_week", "start_time_linear"),
    ("day_of_week", "travel_mode"),
    ("purpose", "distance"),
    ("purpose", "start_time_linear"),
    ("purpose", "travel_mode"),
    ("start_time_linear", "driving_traffic_percent"),
    ("start_time_linear", "cost_driving_con_charge"),
    ("start_time_linear", "travel_mode"),
    ("car_ownership", "fueltype"),
    ("car_ownership", "driving_license"),
    ("car_ownership", "travel_mode"),
    ("fueltype", "cost_driving_con_charge"),
    ("fueltype", "cost_driving_fuel"),
    ("female", "driving_license"),
    ("female", "travel_mode"),
    ("age", "bus_scale"),
    ("age", "driving_license"),
    ("age", "faretype"),
    ("age", "travel_mode"),
    ("driving_license", "travel_mode"),
    ("faretype", "cost_transit"),
    ("faretype", "bus_scale"),
    ("faretype", "travel_mode"),
    ("bus_scale", "cost_transit"),
    ("distance", "cost_driving_fuel"),
    ("distance", "dur_driving"),
    ("distance", "dur_walking"),
    ("distance", "dur_cycling"),
    ("distance", "dur_pt_access"),
    ("distance", "dur_pt_rail"),
    ("distance", "dur_pt_bus"),
    ("distance", "dur_pt_int"),
    ("distance", "pt_n_interchanges"),
    ("distance", "travel_mode"),
    ("pt_n_interchanges", "dur_pt_rail"),
    ("pt_n_interchanges", "dur_pt_bus"),
    ("pt_n_interchanges", "dur_pt_int"),
    ("pt_n_interchanges", "cost_transit"),
    ("driving_traffic_percent", "cost_driving_con_charge"),
    ("driving_traffic_percent", "travel_mode"),
    ("cost_driving_fuel", "cost_driving_con_charge"),
    ("cost_driving_fuel", "travel_mode"),
    ("cost_driving_con_charge", "travel_mode"),
    ("dur_driving", "travel_mode"),
    ("dur_walking", "travel_mode"),
    ("dur_cycling", "travel_mode"),
    ("dur_pt_access", "travel_mode"),
    ("dur_pt_rail", "cost_transit"),
    ("dur_pt_rail", "travel_mode"),
    ("dur_pt_bus", "cost_transit"),
    ("dur_pt_bus", "travel_mode"),
    ("dur_pt_int", "travel_mode"),
    ("cost_transit", "travel_mode"),
]

_DAG_EDGES = {'Chicago': _CHICAGO_EDGES, 'LPMC': _LPMC_EDGES}

# Column that every other column points to in the PREDICTION DAG.
_PREDICTION_TARGET = {'Chicago': 'choice', 'LPMC': 'travel_mode'}


def dag(dataset, type_, df):
    """Build the directed graph of column dependencies passed to DATGAN.

    Args:
        dataset: 'Chicago' or 'LPMC'. Selects the hand-specified edges (FULL /
            TRANSRED) and the target column (PREDICTION).
        type_: Kind of graph to build:
            - 'FULL': the hand-specified graph for ``dataset``;
            - 'TRANSRED': the transitive reduction of the FULL graph (edges
              implied by a longer path are removed);
            - 'LINEAR': a chain linking each column of ``df`` to the next one;
            - 'NOLINKS': every column of ``df`` as an isolated node;
            - 'PREDICTION': every column of ``df`` pointing at the target column.
        df: DataFrame whose columns define the LINEAR / NOLINKS / PREDICTION
            graphs (unused for FULL / TRANSRED).

    Returns:
        networkx.DiGraph

    Raises:
        ValueError: If ``type_`` is unknown, or ``dataset`` has no
            hand-specified graph / target column for the requested type.
    """
    graph = nx.DiGraph()

    if type_ in ('FULL', 'TRANSRED'):
        # Use ==, not `is`: string identity depends on interning and is not equality.
        if dataset not in _DAG_EDGES:
            raise ValueError("No hand-specified DAG for dataset {!r}".format(dataset))
        graph.add_edges_from(_DAG_EDGES[dataset])
        if type_ == 'TRANSRED':
            graph = nx.transitive_reduction(graph)

    elif type_ == 'LINEAR':
        # Add the nodes first so a single-column frame still yields one node.
        graph.add_nodes_from(df.columns)
        graph.add_edges_from(zip(df.columns[:-1], df.columns[1:]))

    elif type_ == 'NOLINKS':
        graph.add_nodes_from(df.columns)

    elif type_ == 'PREDICTION':
        if dataset not in _PREDICTION_TARGET:
            raise ValueError("No prediction target for dataset {!r}".format(dataset))
        target = _PREDICTION_TARGET[dataset]
        graph.add_edges_from((col, target) for col in df.columns if col != target)

    else:
        raise ValueError("Unknown DAG type {!r}".format(type_))

    return graph

In [10]:
def train_DATGAN(dataset, name, max_epoch=1000, batch_size=500, gpu=0):
    """Train one DATGAN model on ``dataset`` and save it under the name 'trained'.

    Reads ``../data/<dataset>/data.csv``. The model uses the DAG returned by
    ``dag(dataset, <type>, df)``, where ``<type>`` is the part of ``name``
    before the first '_' (e.g. 'FULL' for 'FULL_03'). Output is written to
    ``../output/<dataset>_DAG/<name>/``.

    Args:
        dataset: 'Chicago' (trained with DATWGAN) or 'LPMC' (trained with DATWGANGP).
        name: Model name, e.g. 'FULL' or 'FULL_03'.
        max_epoch: Number of training epochs passed to the model.
        batch_size: Batch size passed to the model.
        gpu: GPU index passed to the model.

    Raises:
        ValueError: If ``dataset`` is not supported.
    """
    continuous_columns_by_dataset = {
        'Chicago': ["distance", "age", "departure_time"],
        'LPMC': ['start_time_linear', 'age', 'distance', 'dur_walking', 'dur_cycling', 'dur_pt_access',
                 'dur_pt_rail', 'dur_pt_bus', 'dur_pt_int', 'dur_driving', 'cost_transit',
                 'cost_driving_fuel', 'driving_traffic_percent'],
    }
    # Validate first: an unknown dataset previously left continuous_columns and
    # LIB unbound and crashed later with an UnboundLocalError.
    if dataset not in continuous_columns_by_dataset:
        raise ValueError("Unsupported dataset {!r}; expected one of {}".format(
            dataset, sorted(continuous_columns_by_dataset)))
    continuous_columns = continuous_columns_by_dataset[dataset]

    # Imported here rather than at the top of the notebook. The training loop
    # calls delete_loaded_modules() before each run, so this import loads a
    # fresh copy of the project code every time.
    if dataset == 'Chicago':
        from modules.datgan import DATWGAN as LIB
    else:
        from modules.datgan import DATWGANGP as LIB

    df = pd.read_csv('../data/{}/data.csv'.format(dataset), index_col=False)

    output_folder = '../output/{}_DAG/{}/'.format(dataset, name)

    datgan = LIB(continuous_columns, max_epoch=max_epoch, batch_size=batch_size,
                 output=output_folder, gpu=gpu)

    datgan.fit(df, dag(dataset, name.split('_')[0], df))

    datgan.save('trained', force=True)

In [11]:
dataset = 'LPMC'
# Number of independently trained replicates per DAG type.
n_models = 5
# If True, the training loop copies the data/ folder of the existing
# 'WGAN_WI' model into each new model folder before training.
reuse_data = False

# Every DAG type understood by dag(); train only the subset in model_types.
# (The selection previously overwrote this list as a dead assignment.)
all_model_types = ['FULL', 'TRANSRED', 'LINEAR', 'NOLINKS', 'PREDICTION']
model_types = ['FULL']
assert set(model_types) <= set(all_model_types), "unknown DAG type in model_types"

# With several replicates, suffix each name with a two-digit index
# (FULL_01, FULL_02, ...). train_DATGAN passes the part before the first '_'
# to dag() as the DAG type.
if n_models > 1:
    models = sorted('{}_{:0>2d}'.format(m, i + 1)
                    for i in range(n_models)
                    for m in model_types)
else:
    models = list(model_types)

In [12]:
for i, m in enumerate(models):

    if check_already_trained(dataset, m):
        print("Model \033[1m{}\033[0m ({}/{}) has already been trained.".format(m, i+1, len(models)))
        continue

    print("\rTraining model \033[1m{}\033[0m ({}/{}) ... ".format(m, i+1, len(models)), end="")

    # TensorFlow modules cannot be unloaded, so clear its default graph
    # instead, to start this model without the previous model's ops.
    tf.reset_default_graph()

    # Unload every module imported after cell 4, so train_DATGAN re-imports
    # the model code from scratch.
    delete_loaded_modules(init_modules)

    if reuse_data:
        # The original condition also called is_a_DATGAN(m), which is not
        # defined anywhere in this notebook (NameError once reuse_data=True);
        # every model here is a DATGAN. The destination now matches the
        # '<dataset>_DAG' folder that train_DATGAN writes to.
        # NOTE(review): the source '../output/<dataset>/WGAN_WI' is kept as
        # before; confirm it is not under the *_DAG tree.
        copy_tree('../output/{}/{}/data'.format(dataset, 'WGAN_WI'),
                  '../output/{}_DAG/{}/data'.format(dataset, m))

    start_time = time.time()
    try:
        with redirect_output('training.log'):
            train_DATGAN(dataset, m)
        elapsed = time.time() - start_time
    finally:
        # Close and detach the 'tensorpack' logger's handlers (presumably
        # attached during training) even if training fails, so the next model
        # does not inherit them or their open files.
        tensorpack_logger = logging.getLogger('tensorpack')
        for handler in tensorpack_logger.handlers:
            handler.close()
        tensorpack_logger.handlers = []

    print("Done in {}.".format(time_to_str(elapsed)))

print("\033[1m FINISHED!\033[0m")

Training model [1mFULL_01[0m (1/5) ... Done in 31 minutes and 00 seconds.
Training model [1mFULL_02[0m (2/5) ... Done in 30 minutes and 53 seconds.
Training model [1mFULL_03[0m (3/5) ... Done in 31 minutes and 30 seconds.
Training model [1mFULL_04[0m (4/5) ... Done in 32 minutes and 29 seconds.
Training model [1mFULL_05[0m (5/5) ... Done in 42 minutes and 42 seconds.
[1m FINISHED![0m
