In [1]:
from datgan import DATGAN
import datgan

import numpy as np
import pandas as pd
import networkx as nx

# For the Python notebook
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
# Load the LPMC trips table; index_col=False tells pandas not to use the first column as the index.
df = pd.read_csv(
    filepath_or_buffer='../../data/LPMC/trips.csv',
    index_col=False,
)

In [3]:
# Preview the first five rows (five is the pandas default, spelled out here).
df.head(n=5)

Unnamed: 0,travel_mode,purpose,fueltype,faretype,bus_scale,travel_year,travel_month,travel_date,day_of_week,start_time_linear,...,dur_driving,cost_transit,cost_driving_fuel,cost_driving_con_charge,driving_traffic_percent,hh_vehicles,hh_borough,hh_income,hh_people,dur_pt_int
0,drive,HBO,Diesel_Car,full,1.0,2012,8,7,2,20.0,...,0.208611,1.5,0.57,0.0,0.098535,1,Bexley,35-50k,2,0.0
1,drive,HBW,Diesel_Car,full,1.0,2013,2,8,5,15.0,...,0.471944,3.0,1.62,0.0,0.354915,1,Harrow,5-10k,3,0.133333
2,pt,HBO,Average_Car,full,0.5,2014,10,8,3,14.0,...,0.238333,0.75,0.62,0.0,0.212121,0,Lambeth,50-75k,4,0.0
3,pt,HBE,Average_Car,dis,0.5,2014,3,10,1,10.5,...,0.308889,0.75,0.6,10.5,0.684353,0,Hackney,10-15k,2,0.0
4,walk,HBW,Petrol_Car,full,1.0,2013,1,24,4,16.833333,...,0.0775,1.5,0.19,0.0,0.046595,2,Lambeth,50-75k,5,0.0


In [5]:
# Describe every column of `df` for DATGAN.
#
# Continuous variables are listed explicitly below; each entry gives the
# variable type, its bounds, whether it takes discrete values, and (for some
# variables) an `enforce_bounds` flag and an `apply_func` transformation.
# Every column of `df` that is not listed here is registered as categorical by
# the loop at the end of the cell.
#
# `np.inf` is used for open upper bounds: the `np.infty` alias was removed in
# NumPy 2.0 and raises AttributeError there.
data_info = {
    'start_time_linear': {
        'type': 'continuous',
        'bounds': [0.0, 23.999],
        'discrete': False,
    },
    'age': {
        'type': 'continuous',
        'bounds': [0, 100],
        'discrete': True
    },
    'distance': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'discrete': True,
        # log(x + 1) transformation
        'apply_func': (lambda x: np.log(x+1))
    },
    'dur_walking': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'enforce_bounds': True,
        'discrete': False,
        'apply_func': (lambda x: np.log(x+1))
    },
    'dur_cycling': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'enforce_bounds': True,
        'discrete': False,
        'apply_func': (lambda x: np.log(x+1))
    },
    'dur_pt_access': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'enforce_bounds': True,
        'discrete': False,
        'apply_func': (lambda x: np.log(x+1))
    },
    'dur_pt_rail': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'enforce_bounds': True,
        'discrete': False,
    },
    'dur_pt_bus': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'enforce_bounds': True,
        'discrete': False,
    },
    'dur_pt_int': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'enforce_bounds': True,
        'discrete': False,
    },
    'dur_driving': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'enforce_bounds': True,
        'discrete': False,
        'apply_func': (lambda x: np.log(x+1))
    },
    'cost_transit': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'enforce_bounds': True,
        'discrete': False,
    },
    'cost_driving_fuel': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'enforce_bounds': True,
        'discrete': False,
        'apply_func': (lambda x: np.log(x+1))
    },
    'driving_traffic_percent': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'discrete': False,
    },
}

# Add the other variables as categorical (entries defined above are left untouched)
for c in df.columns:
    data_info.setdefault(c, {'type': 'categorical'})

In [6]:
# Personalised graph: a hand-specified directed graph over the columns of the
# trips table. Each tuple is one directed edge (source, target).
#
# The edge ("travel_date", "day_of_week") used to be listed twice; the second
# occurrence is dropped. Adding an edge that already exists leaves a DiGraph
# unchanged, and the first occurrence is kept in place, so node and edge
# insertion order are the same as before.
graph = nx.DiGraph()

graph.add_edges_from([
    # Calendar and trip context
    ("travel_year", "travel_month"),
    ("travel_date", "day_of_week"),
    ("travel_month", "travel_date"),
    ("travel_month", "driving_traffic_percent"),
    ("travel_month", "day_of_week"),
    ("travel_month", "travel_mode"),
    ("day_of_week", "driving_traffic_percent"),
    ("day_of_week", "cost_driving_con_charge"),
    ("day_of_week", "purpose"),
    ("day_of_week", "start_time_linear"),
    ("day_of_week", "travel_mode"),
    ("purpose", "distance"),
    ("purpose", "start_time_linear"),
    ("purpose", "travel_mode"),
    ("start_time_linear", "driving_traffic_percent"),
    ("start_time_linear", "cost_driving_con_charge"),
    ("start_time_linear", "travel_mode"),
    # Vehicle and individual attributes
    ("hh_vehicles", "fueltype"),
    ("hh_vehicles", "driving_license"),
    ("hh_vehicles", "travel_mode"),
    ("fueltype", "cost_driving_con_charge"),
    ("fueltype", "cost_driving_fuel"),
    ("female", "driving_license"),
    ("female", "travel_mode"),
    ("age", "bus_scale"),
    ("age", "driving_license"),
    ("age", "faretype"),
    ("age", "travel_mode"),
    ("driving_license", "travel_mode"),
    ("faretype", "cost_transit"),
    ("faretype", "bus_scale"),
    ("faretype", "travel_mode"),
    ("bus_scale", "cost_transit"),
    # Distance, durations and costs
    ("distance", "cost_driving_fuel"),
    ("distance", "dur_driving"),
    ("distance", "dur_walking"),
    ("distance", "dur_cycling"),
    ("distance", "dur_pt_access"),
    ("distance", "dur_pt_rail"),
    ("distance", "dur_pt_bus"),
    ("distance", "dur_pt_int"),
    ("distance", "pt_n_interchanges"),
    ("distance", "travel_mode"),
    ("pt_n_interchanges", "dur_pt_rail"),
    ("pt_n_interchanges", "dur_pt_bus"),
    ("pt_n_interchanges", "dur_pt_int"),
    ("pt_n_interchanges", "cost_transit"),
    ("driving_traffic_percent", "cost_driving_con_charge"),
    ("driving_traffic_percent", "travel_mode"),
    ("cost_driving_fuel", "cost_driving_con_charge"),
    ("cost_driving_fuel", "travel_mode"),
    ("cost_driving_con_charge", "travel_mode"),
    ("dur_driving", "travel_mode"),
    ("dur_walking", "travel_mode"),
    ("dur_cycling", "travel_mode"),
    ("dur_pt_access", "travel_mode"),
    ("dur_pt_rail", "cost_transit"),
    ("dur_pt_rail", "travel_mode"),
    ("dur_pt_bus", "cost_transit"),
    ("dur_pt_bus", "travel_mode"),
    ("dur_pt_int", "travel_mode"),
    ("cost_transit", "travel_mode"),
    # Household attributes
    ("hh_borough", "hh_income"),
    ("hh_borough", "travel_mode"),
    ("hh_borough", "distance"),
    ("hh_borough", "hh_people"),
    ("hh_income", "hh_vehicles"),
    ("hh_income", "age"),
    ("hh_income", "hh_people"),
    ("hh_people", "age"),
    ("hh_people", "female"),
    ("hh_people", "hh_vehicles")
])

In [7]:
# Label for this run; it is interpolated into the model output folder path.
name = 'LPMC'

In [8]:
# Folder handed to DATGAN as its `output` argument, relative to the notebook's working directory.
output_folder = '../output/{}/'.format(name)

In [9]:
# Build the DATGAN model with the WGGP loss, a batch size of 1101 and 1000 epochs.
# NOTE(review): this assignment rebinds the name `datgan`, which the first cell
# also bound to the imported `datgan` module; from here on `datgan` refers to the
# DATGAN instance and the module is no longer reachable under that name.
datgan = DATGAN(output=output_folder,
                loss_function='WGGP',
                batch_size=1101,
                num_epochs=1000)

In [10]:
# Train the model on the trips table, using the variable descriptions and the hand-built graph.
# NOTE(review): presumably `preprocessed_data_path` points at a cache of the encoded data that is
# reused when present (the recorded output reports it was loaded) — confirm against the DATGAN docs.
datgan.fit(df, data_info, graph, preprocessed_data_path='../output/encoded_LPMC')

Preprocessed data have been loaded!
Start training DATGAN with the WGGP loss (05/05/2022 11:06:51).


Training DATGAN: 100%|██████████| 1000/1000 [1:04:25<00:00,  3.87s/it]

DATGAN has finished training (05/05/2022 12:11:17) - Training time: 01 hour, 04 minutes, and 26 seconds





In [12]:
# Draw five synthetic datasets, each with as many rows as the original table,
# and write them out as DATGAN_01.csv ... DATGAN_05.csv.
n_rows = len(df)
for file_idx in range(1, 6):
    synthetic = datgan.sample(n_rows)
    synthetic.to_csv('../../data/synthetic/test/DATGAN_{:02d}.csv'.format(file_idx), index=False)

Sampling from DATGAN: 100%|██████████| 16904/16904 [00:21<00:00, 801.55it/s]
Sampling from DATGAN: 100%|██████████| 16904/16904 [00:21<00:00, 770.07it/s]
Sampling from DATGAN: 100%|██████████| 16904/16904 [00:39<00:00, 429.06it/s]
Sampling from DATGAN: 100%|██████████| 16904/16904 [00:20<00:00, 814.42it/s]
Sampling from DATGAN: 100%|██████████| 16904/16904 [00:20<00:00, 813.83it/s]


In [None]:
# Per-borough bookkeeping:
#   nbrs[borough] -> number of rows in that borough's nomis CSV
#   dct[borough]  -> list that will collect DataFrame chunks for that borough
nbrs = {}
dct = {}

for borough in df.hh_borough.unique():
    reference = pd.read_csv('../../data/nomis/{}.csv'.format(borough))

    dct[borough] = []
    nbrs[borough] = len(reference)

In [None]:
# Keep sampling from the trained model until every borough has collected
# exactly nbrs[borough] synthetic rows. Each pass draws one batch and hands out
# its rows borough by borough; the chunks are accumulated in dct[borough].
#
# NOTE: the loop only ends once every remaining count reaches zero, so if the
# model never produces rows for a borough that still needs some, it will not
# terminate.
remaining_boroughs = set(df.hh_borough.unique())

count = 1
while remaining_boroughs:

    print("Pass {} - Remaining boroughs: {}".format(count, len(remaining_boroughs)))

    samp = datgan.sample(100000)

    # Boroughs completed during this pass; removed after the loop so the set
    # is not modified while it is being iterated.
    borough_to_remove = []

    for r in remaining_boroughs:
        tmp = samp[samp.hh_borough == r]

        # Never keep more rows than this borough still needs.
        if len(tmp) > nbrs[r]:
            tmp = tmp.sample(nbrs[r], replace=False)

        nbrs[r] -= len(tmp)
        dct[r].append(tmp)

        if nbrs[r] == 0:
            borough_to_remove.append(r)

    remaining_boroughs.difference_update(borough_to_remove)

    count += 1

In [None]:
# Stitch together the chunks collected for each borough and write one CSV per borough.
for borough, chunks in dct.items():
    combined = pd.concat(chunks)
    combined.to_csv('../../data/synthetic/DATGAN/{}.csv'.format(borough), index=False)