In [1]:
from datgan import DATGAN
import datgan

import numpy as np
import pandas as pd
import networkx as nx

# For the Python notebook
%matplotlib inline
%reload_ext autoreload
%autoreload 2

import tensorflow as tf
#tf.config.run_functions_eagerly(True)

In [11]:
# Load the LPMC trips file used as training data for the model.
trips_csv = '../../data/LPMC/trips_small_bias.csv'
df = pd.read_csv(trips_csv, index_col=False)

In [12]:
# Number of trips (rows) in the training data.
df.shape[0]

10520

In [5]:
# Preview the first five rows of the training data.
df.head(n=5)

Unnamed: 0,travel_mode,purpose,faretype,day_of_week,start_time_linear,age,female,driving_license,distance,dur_walking,dur_cycling,dur_driving,driving_traffic_percent,hh_vehicles,hh_income,hh_people,dur_pt,hh_region
0,walk,HBO,full,6,15.833333,30,0,1,2145,0.553056,0.1575,0.158333,0.473684,0,35-50k,2,0.436389,Central London
1,pt,HBO,full,7,10.0,50,1,1,1789,0.473333,0.160556,0.135,0.547325,0,15-20k,5,0.271111,Central London
2,pt,HBW,full,5,17.0,55,1,1,10036,2.411667,0.761389,0.638056,0.543317,2,75-100k,2,0.830833,South London
3,drive,HBO,full,6,16.883333,51,1,1,1531,0.423889,0.168611,0.110556,0.268844,1,>100k,2,0.295556,East London
4,pt,HBW,full,3,7.5,39,1,1,1124,0.275833,0.123611,0.081667,0.156463,1,>100k,4,0.124722,South London


In [6]:
# First, define the specificities of continuous variables.
# Open upper bounds use `np.inf`: the `np.infty` alias was removed in NumPy 2.0,
# so the original spelling raises AttributeError on current NumPy.


def log_shift(x):
    """Return log(x + 1); the `apply_func` shared by distance and the trip durations."""
    return np.log(x + 1)


data_info = {
    'start_time_linear': {
        'type': 'continuous',
        'bounds': [0.0, 23.999],
        'discrete': False,
    },
    'age': {
        'type': 'continuous',
        'bounds': [0, 100],
        'discrete': True
    },
    'distance': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'discrete': True,
        'apply_func': log_shift
    },
    # The four trip durations share one specification; the comprehension builds a
    # fresh dict (and bounds list) per column so they never alias each other.
    **{
        col: {
            'type': 'continuous',
            'bounds': [0, np.inf],
            'enforce_bounds': True,
            'discrete': False,
            'apply_func': log_shift
        }
        for col in ('dur_walking', 'dur_cycling', 'dur_pt', 'dur_driving')
    },
    'driving_traffic_percent': {
        'type': 'continuous',
        'bounds': [0, np.inf],
        'discrete': False,
    },
}

# Add the other variables as categorical (explicit specifications above win).
for col in df.columns:
    data_info.setdefault(col, {'type': 'categorical'})

In [7]:
# List the column names (for a DataFrame, keys() returns the columns Index).
df.keys()

Index(['travel_mode', 'purpose', 'faretype', 'day_of_week',
       'start_time_linear', 'age', 'female', 'driving_license', 'distance',
       'dur_walking', 'dur_cycling', 'dur_driving', 'driving_traffic_percent',
       'hh_vehicles', 'hh_income', 'hh_people', 'dur_pt', 'hh_region'],
      dtype='object')

In [8]:
# Personalised dependency graph for DATGAN.
# Each key is a parent variable; its list holds the variables it directly influences.
# Edges are listed in the same order as the original hand-written edge list.
parents_to_children = {
    'hh_region': ['hh_people', 'distance', 'hh_income', 'travel_mode'],
    'hh_income': ['hh_vehicles'],
    'hh_people': ['hh_vehicles'],
    'age': ['hh_people', 'faretype', 'driving_license', 'purpose', 'travel_mode'],
    'female': ['driving_license', 'hh_people'],
    'driving_license': ['travel_mode'],
    'hh_vehicles': ['driving_license', 'travel_mode'],
    'faretype': ['travel_mode'],
    'day_of_week': ['purpose', 'start_time_linear', 'driving_traffic_percent'],
    'purpose': ['start_time_linear', 'travel_mode', 'distance'],
    'start_time_linear': ['driving_traffic_percent'],
    'driving_traffic_percent': ['dur_driving'],
    'distance': ['driving_traffic_percent', 'dur_walking', 'dur_cycling',
                 'dur_pt', 'dur_driving', 'travel_mode'],
}

graph = nx.DiGraph()
graph.add_edges_from(
    (parent, child)
    for parent, children in parents_to_children.items()
    for child in children
)

# Fail fast, before the expensive fit, if the hand-written graph is malformed:
# it must be acyclic, and every data column is expected to appear as a node.
assert nx.is_directed_acyclic_graph(graph), "dependency graph contains a cycle"
missing_nodes = set(df.columns) - set(graph.nodes)
unknown_nodes = set(graph.nodes) - set(df.columns)
assert not missing_nodes and not unknown_nodes, (
    f"graph/column mismatch: missing={missing_nodes}, unknown={unknown_nodes}"
)

In [9]:
name = "DATGAN"  # model name, reused to build `output_folder`

In [10]:
# DATGAN output directory for this model: ../output/<name>/
output_folder = f'../output/{name}/'

In [11]:
# Model hyper-parameters for this run.
# NOTE(review): assigning to `datgan` shadows the `datgan` module imported in the
# first cell; after this cell the module is no longer reachable under that name.
datgan_kwargs = dict(
    output=output_folder,
    loss_function='WGGP',
    batch_size=1101,
    num_epochs=1000,
)
datgan = DATGAN(**datgan_kwargs)

In [12]:
# Fit DATGAN on the trips using the variable specifications and the dependency graph;
# preprocessed (encoded) data are read from / stored at `encoded_path`.
encoded_path = '../output/encoded_LPMC'
datgan.fit(df, data_info, graph, preprocessed_data_path=encoded_path)

Preprocessed data have been loaded!
Start training DATGAN with the WGGP loss (12/05/2022 16:17:22).
Restored models from epoch 1000.


Training DATGAN: 0it [00:00, ?it/s]

DATGAN has finished training (12/05/2022 16:17:22) - Training time: 00 second





In [15]:
# Draw five synthetic datasets, each with as many rows as the training data,
# and save them as DATGAN_01.csv ... DATGAN_05.csv.
for run in range(1, 6):
    samp = datgan.sample(len(df))
    samp.to_csv(f'../../data/synthetic/test/DATGAN_{run:02d}.csv', index=False)


Sampling from DATGAN:   0%|          | 0/16904 [00:00<?, ?it/s][A
Sampling from DATGAN:   7%|▋         | 1100/16904 [00:02<00:33, 469.94it/s][A
Sampling from DATGAN:  13%|█▎        | 2201/16904 [00:03<00:21, 696.17it/s][A
Sampling from DATGAN:  20%|█▉        | 3301/16904 [00:04<00:16, 804.23it/s][A
Sampling from DATGAN:  26%|██▌       | 4401/16904 [00:05<00:14, 860.60it/s][A
Sampling from DATGAN:  33%|███▎      | 5500/16904 [00:06<00:12, 901.30it/s][A
Sampling from DATGAN:  39%|███▉      | 6601/16904 [00:07<00:10, 938.35it/s][A
Sampling from DATGAN:  46%|████▌     | 7702/16904 [00:08<00:09, 974.41it/s][A
Sampling from DATGAN:  52%|█████▏    | 8802/16904 [00:09<00:08, 997.13it/s][A
Sampling from DATGAN:  59%|█████▊    | 9901/16904 [00:11<00:06, 1004.73it/s][A
Sampling from DATGAN:  65%|██████▌   | 11000/16904 [00:12<00:05, 1003.93it/s][A
Sampling from DATGAN:  72%|███████▏  | 12100/16904 [00:13<00:04, 1013.35it/s][A
Sampling from DATGAN:  78%|███████▊  | 13201/16904 [00:14<

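# Regions

Sample synthetic trips per household region until each region has as many rows as its NOMIS file (`../../data/nomis/<region>.csv`).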

In [12]:
# Per household region: how many synthetic rows are still needed (initialised to
# the row count of that region's NOMIS file), and a list collecting sampled batches.
regions = df.hh_region.unique()
nbrs = {r: len(pd.read_csv('../../data/nomis/{}.csv'.format(r))) for r in regions}
dct = {r: [] for r in regions}

In [12]:
# Repeatedly draw batches of 100000 synthetic rows and hand each region whose quota
# in `nbrs` is not yet met its share of the batch. `nbrs` is decremented and `dct`
# extended in place, so this cell is not idempotent.
# NOTE(review): the loop never terminates if the generator stops producing some region.
remaining_regions = set(df.hh_region.unique())

count = 1
while len(remaining_regions) > 0:
    print("Pass {} - Remaining regions: {}".format(count, len(remaining_regions)))

    samp = datgan.sample(100000)

    filled = []
    for region in remaining_regions:
        quota = nbrs[region]
        rows = samp[samp.hh_region == region]

        # Keep only as many rows as the region still needs (drawn without replacement).
        if quota < len(rows):
            rows = rows.sample(quota, replace=False)

        nbrs[region] = quota - len(rows)
        dct[region].append(rows)

        if nbrs[region] == 0:
            filled.append(region)

    # Drop completed regions after the loop, never while iterating over the set.
    remaining_regions.difference_update(filled)
    count = count + 1

Pass 1 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:38<00:00, 1013.39it/s]


Pass 2 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1045.48it/s]


Pass 3 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:36<00:00, 1036.20it/s]


Pass 4 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:36<00:00, 1034.61it/s]


Pass 5 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:36<00:00, 1033.30it/s]


Pass 6 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:36<00:00, 1032.58it/s]


Pass 7 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1042.89it/s]


Pass 8 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:36<00:00, 1040.92it/s]


Pass 9 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:36<00:00, 1039.19it/s]


Pass 10 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1044.16it/s]


Pass 11 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:33<00:00, 1065.33it/s]


Pass 12 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1061.59it/s]


Pass 13 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1053.49it/s]


Pass 14 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1055.08it/s]


Pass 15 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1054.23it/s]


Pass 16 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1051.97it/s]


Pass 17 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1049.38it/s]


Pass 18 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1052.37it/s]


Pass 19 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1049.81it/s]


Pass 20 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1047.57it/s]


Pass 21 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1049.09it/s]


Pass 22 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1046.14it/s]


Pass 23 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1047.73it/s]


Pass 24 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1050.25it/s]


Pass 25 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1050.55it/s]


Pass 26 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1053.54it/s]


Pass 27 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1050.88it/s]


Pass 28 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1047.61it/s]


Pass 29 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1043.84it/s]


Pass 30 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:36<00:00, 1041.63it/s]


Pass 31 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:36<00:00, 1041.49it/s]


Pass 32 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1045.70it/s]


Pass 33 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1047.26it/s]


Pass 34 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1045.27it/s]


Pass 35 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1051.33it/s]


Pass 36 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1058.64it/s]


Pass 37 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1050.27it/s]


Pass 38 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1054.20it/s]


Pass 39 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1053.55it/s]


Pass 40 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1054.69it/s]


Pass 41 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1054.83it/s]


Pass 42 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1055.80it/s]


Pass 43 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1059.80it/s]


Pass 44 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1062.18it/s]


Pass 45 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:33<00:00, 1064.23it/s]


Pass 46 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1060.33it/s]


Pass 47 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1061.23it/s]


Pass 48 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1061.12it/s]


Pass 49 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1062.98it/s]


Pass 50 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1062.09it/s]


Pass 51 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1063.60it/s]


Pass 52 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1060.17it/s]


Pass 53 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1061.29it/s]


Pass 54 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1063.67it/s]


Pass 55 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1063.65it/s]


Pass 56 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:33<00:00, 1064.27it/s]


Pass 57 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1058.85it/s]


Pass 58 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1060.77it/s]


Pass 59 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1062.54it/s]


Pass 60 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1063.43it/s]


Pass 61 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1062.30it/s]


Pass 62 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1061.79it/s]


Pass 63 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1060.72it/s]


Pass 64 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1060.07it/s]


Pass 65 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1058.75it/s]


Pass 66 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1057.07it/s]


Pass 67 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1056.60it/s]


Pass 68 - Remaining regions: 5


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1058.08it/s]


Pass 69 - Remaining regions: 3


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1057.22it/s]


Pass 70 - Remaining regions: 3


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1058.34it/s]


Pass 71 - Remaining regions: 3


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1058.05it/s]


Pass 72 - Remaining regions: 3


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1054.95it/s]


Pass 73 - Remaining regions: 3


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1056.03it/s]


Pass 74 - Remaining regions: 3


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:35<00:00, 1052.60it/s]


Pass 75 - Remaining regions: 3


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1059.19it/s]


Pass 76 - Remaining regions: 3


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1061.23it/s]


Pass 77 - Remaining regions: 3


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1061.54it/s]


Pass 78 - Remaining regions: 3


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:33<00:00, 1064.23it/s]


Pass 79 - Remaining regions: 2


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1061.78it/s]


Pass 80 - Remaining regions: 2


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1060.93it/s]


Pass 81 - Remaining regions: 2


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1061.61it/s]


Pass 82 - Remaining regions: 2


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1061.19it/s]


Pass 83 - Remaining regions: 1


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1062.20it/s]


Pass 84 - Remaining regions: 1


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1062.39it/s]


Pass 85 - Remaining regions: 1


Sampling from DATGAN: 100%|██████████| 100000/100000 [01:34<00:00, 1062.39it/s]


In [13]:
import os

# Write one CSV per household region by concatenating the batches collected above.
# The target folder is created if needed: the sampling loop that fills `dct` takes
# hours, and a missing folder would otherwise raise FileNotFoundError here.
os.makedirs('../../data/synthetic/DATGAN', exist_ok=True)

for r, frames in dct.items():
    tmp = pd.concat(frames)
    tmp.to_csv('../../data/synthetic/DATGAN/{}.csv'.format(r), index=False)