In [95]:
import os
import pickle
import re

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn import preprocessing
from tensorflow import keras

In [96]:

# Choose the project root once so every later path can be written as
# f"{proj_dir}...".
if 'google.colab' not in str(get_ipython()):
  # Not running under Colab: paths are relative to the working directory.
  proj_dir = ""
else:
  # Running under Colab: mount Google Drive and point at the project folder.
  from google.colab import drive
  drive.mount('/content/drive')
  proj_dir = "/content/drive/MyDrive/ece884_project/"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [97]:

# List everything in the log directory, then keep only entries named exactly
# "logGAN<digits>". Stray entries (checkpoint folders, files with an
# extension, ...) are skipped instead of crashing the int() conversion.
models = os.listdir(f"{proj_dir}logs/")
model_number = [
    int(match.group(1))
    for match in (re.fullmatch(r"logGAN(\d+)", name) for name in models)
    if match
]
if not model_number:
    raise ValueError(f"no logGAN<N> files found in {proj_dir}logs/")
# NOTE(review): the run that is loaded is two below the highest-numbered log.
# The reason for the offset is not visible here (presumably the newest runs are
# incomplete) - confirm, and confirm that this run number actually exists.
last_model = max(model_number)-2

In [98]:
# Read the cleaned taxi data, remember its column labels, and keep the values
# as a plain numpy array in `df` (later cells read `df.shape`).
taxi_frame = pd.read_csv(f"{proj_dir}data_clean/taxi.csv")
column_names = taxi_frame.columns
df = taxi_frame.to_numpy()
# Fit a standardising scaler on the real data; it is used further down to map
# generated rows back to the original feature units.
# NOTE(review): assumes the GANs were trained on data standardised the same
# way - confirm against the training notebook.
scaler = preprocessing.StandardScaler().fit(df)

In [99]:
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    # SECURITY: pickle.load can execute arbitrary code while deserialising.
    # Only point this at files written by this project, never at untrusted data.
    with open(path, "rb") as fp:
        return pickle.load(fp)


# Everything below belongs to the run selected by `last_model`.
# Contents of the log file are not shown here - presumably the training log.
results = _load_pickle(f"{proj_dir}logs/logGAN{last_model}")
# Passed to generated_data_filter as its generator / discriminator lists.
generators_saved = _load_pickle(f"{proj_dir}saved_models/list_of_models/gen/generators{last_model}")
discriminators_saved = _load_pickle(f"{proj_dir}saved_models/list_of_models/disc/discriminators{last_model}")

In [100]:
def generated_data_filter(gen, desc, points_to_gen, threashold, dims):
    """
    Sample from every generator and keep only the rows that its paired
    discriminator scores above a threshold.

    inputs
    gen, the list of generators written by gan.ipynb; each one is called as
    generator(noise).

    desc, the list of discriminators from gan.ipynb. It is paired with `gen`
    by position; pairing stops at the end of the shorter list.

    points_to_gen, number of datapoints each generator is asked to produce.

    threashold, the discriminator score a generated row must exceed to be kept.
    The score is assumed to be the predicted probability that the row is real:
    with threashold = 0.99 every row with less than (or exactly) a 0.99 chance
    of being real is dropped. This value needs tuning.

    dims, shape of the real dataset. Only dims[1] (the number of columns) is
    read; it sets both the width of the returned array and the width of the
    noise fed to each generator.

    returns
    2-D numpy array with dims[1] columns holding the kept rows of all
    generators, in generator order. When nothing passes the filter the result
    is an empty float32 array of shape (0, dims[1]).
    """
    n_col = dims[1]
    # Start from an empty float32 piece so the result keeps its (0, n_col)
    # shape and dtype even if no generator contributes a row.
    kept_batches = [np.empty((0, n_col), np.float32)]

    for generator, discriminator in zip(gen, desc):
        # Gaussian noise whose mean is drawn once per generator from
        # U(-0.5, 0.5).
        # NOTE(review): the noise width equals the data width - confirm this
        # matches the input size the generators were built with.
        noise = tf.random.normal(shape=(points_to_gen, n_col),
                                 mean=tf.random.uniform((1,), minval=-0.5, maxval=0.5),
                                 stddev=.2)
        generated_data = generator(noise)
        judgement = discriminator(generated_data)  # one score per generated row
        keep_mask = np.ravel(judgement) > threashold
        kept_batches.append(np.compress(keep_mask, generated_data, axis=0))

    # Join once at the end instead of re-copying the accumulated array on
    # every iteration.
    return np.concatenate(kept_batches, axis=0)

In [101]:
# Ask every saved generator for a single sample and keep only the rows whose
# discriminator score exceeds 0.99. `dims` passes the real data's shape so the
# filter knows the column count.
generated_dataset = generated_data_filter(generators_saved, discriminators_saved, points_to_gen=1, threashold=0.99, dims=df.shape)

In [102]:
# Undo the scaler's standardisation so the kept rows are back in the original
# feature units, and label the columns with the names read from taxi.csv.
generated_data = pd.DataFrame(scaler.inverse_transform(generated_dataset), columns=column_names)

In [103]:
# Show the generated rows in original units.
generated_data

Unnamed: 0,year_pickup_datetime,quarter_pickup_datetime,month_pickup_datetime,dayofweek_pickup_datetime,dayofyear_pickup_datetime,dayofmonth_pickup_datetime,weekofyear_pickup_datetime,hour_pickup_datetime,minute_pickup_datetime,second_pickup_datetime,...,hour_dropoff_datetime,minute_dropoff_datetime,second_dropoff_datetime,horizon_dropoff_datetime,passenger_count,pickup_longitude,pickup_latitude,dropoff_longitude,dropoff_latitude,trip_duration
0,2016.0,2.003502,3.61864,5.004414,143.40213,24.20715,17.768269,13.607066,29.590158,29.473591,...,13.599205,47.048653,46.829582,12348015.0,2.978771,-73.973488,40.750996,-73.973419,40.751801,959.492249
1,2016.0,2.003502,3.617143,5.004414,143.40213,24.20715,17.653955,13.607094,29.590158,29.473591,...,13.599216,47.048653,46.829582,12347743.0,2.978771,-73.973488,40.750996,-73.973419,40.751801,959.492249
2,2016.0,2.003502,3.617446,5.004414,143.40213,24.20715,17.626736,13.60711,29.590158,29.473591,...,13.599224,47.048653,46.829582,12347723.0,2.978771,-73.973488,40.750996,-73.973419,40.751801,959.492249
3,2016.0,2.003502,3.615467,5.004414,143.40213,24.20715,17.554825,13.607078,29.590158,29.473591,...,13.599204,47.048653,46.829582,12347552.0,2.978771,-73.973488,40.750999,-73.973419,40.751801,959.492249
4,2016.0,2.003502,3.617052,5.004414,143.40213,24.20715,17.652096,13.607076,29.590158,29.473591,...,13.599206,47.048653,46.829582,12347764.0,2.978771,-73.973488,40.750999,-73.973419,40.751801,959.492249
5,2016.0,2.003502,3.613798,5.004414,143.40213,24.20715,17.470144,13.607032,29.590158,29.473591,...,13.599174,47.048653,46.829582,12347426.0,2.978771,-73.973488,40.751003,-73.973419,40.751801,959.492249
6,2016.0,2.003502,3.614933,5.004414,143.40213,24.20715,17.606894,13.607017,29.590158,29.473591,...,13.599173,47.048653,46.829582,12347655.0,2.978771,-73.973488,40.751003,-73.973419,40.751801,959.492249
7,2016.0,2.003502,3.618314,5.004414,143.40213,24.20715,17.777763,13.607054,29.590158,29.473591,...,13.599199,47.048653,46.829582,12347965.0,2.978771,-73.973488,40.750996,-73.973419,40.751801,959.492249
8,2016.0,2.003502,3.616611,5.004414,143.40213,24.20715,17.672615,13.607053,29.590158,29.473591,...,13.599194,47.048653,46.829582,12347836.0,2.978771,-73.973488,40.750999,-73.973419,40.751801,959.492249
9,2016.0,2.003502,3.620822,5.004414,143.40213,24.20715,17.903833,13.607049,29.590158,29.473591,...,13.599198,47.048653,46.829582,12348354.0,2.978771,-73.973488,40.750996,-73.973419,40.751801,959.492249


In [104]:
# Per-column standard deviation across the generated rows; values near zero
# mean the kept samples are almost identical in that column.
generated_data.std(axis=0)

year_pickup_datetime           0.000000e+00
quarter_pickup_datetime        0.000000e+00
month_pickup_datetime          1.942655e-03
dayofweek_pickup_datetime      2.890070e-06
dayofyear_pickup_datetime      7.706852e-05
dayofmonth_pickup_datetime     1.926713e-06
weekofyear_pickup_datetime     1.197101e-01
hour_pickup_datetime           2.109421e-05
minute_pickup_datetime         1.348699e-05
second_pickup_datetime         7.706853e-06
horizon_pickup_datetime        4.040610e+00
year_dropoff_datetime          0.000000e+00
quarter_dropoff_datetime       9.633566e-07
month_dropoff_datetime         0.000000e+00
dayofweek_dropoff_datetime     2.408391e-06
dayofyear_dropoff_datetime     3.082741e-05
dayofmonth_dropoff_datetime    1.926713e-06
weekofyear_dropoff_datetime    1.475678e-03
hour_dropoff_datetime          1.302207e-05
minute_dropoff_datetime        1.541371e-05
second_dropoff_datetime        3.853426e-06
horizon_dropoff_datetime       2.611704e+02
passenger_count                1

In [105]:
# Element-wise exact equality between `gan20` and the first generated row.
# NOTE(review): `gan20` is not assigned in any cell visible in this notebook,
# so this relies on leftover kernel state - define or load it before this cell.
gan20 == generated_data.iloc[0, :]

year_pickup_datetime            True
quarter_pickup_datetime        False
month_pickup_datetime          False
dayofweek_pickup_datetime      False
dayofyear_pickup_datetime      False
dayofmonth_pickup_datetime     False
weekofyear_pickup_datetime     False
hour_pickup_datetime           False
minute_pickup_datetime          True
second_pickup_datetime         False
horizon_pickup_datetime        False
year_dropoff_datetime           True
quarter_dropoff_datetime       False
month_dropoff_datetime         False
dayofweek_dropoff_datetime     False
dayofyear_dropoff_datetime     False
dayofmonth_dropoff_datetime    False
weekofyear_dropoff_datetime    False
hour_dropoff_datetime          False
minute_dropoff_datetime         True
second_dropoff_datetime         True
horizon_dropoff_datetime       False
passenger_count                False
pickup_longitude                True
pickup_latitude                False
dropoff_longitude               True
dropoff_latitude                True
t

In [106]:
# Show the reference row used in the comparison.
# NOTE(review): `gan20` has no visible definition in this notebook - confirm
# where it comes from.
gan20

year_pickup_datetime           2.016000e+03
quarter_pickup_datetime        1.996221e+00
month_pickup_datetime          5.192637e+00
dayofweek_pickup_datetime      3.050375e+00
dayofyear_pickup_datetime      9.183853e+01
dayofmonth_pickup_datetime     1.550402e+01
weekofyear_pickup_datetime     1.399747e+01
hour_pickup_datetime           2.000617e+01
minute_pickup_datetime         2.959016e+01
second_pickup_datetime         4.679343e+01
horizon_pickup_datetime        1.202994e+07
year_dropoff_datetime          2.016000e+03
quarter_dropoff_datetime       2.002290e+00
month_dropoff_datetime         3.526658e+00
dayofweek_dropoff_datetime     3.054422e+00
dayofyear_dropoff_datetime     9.184747e+01
dayofmonth_dropoff_datetime    1.550449e+01
weekofyear_dropoff_datetime    1.387696e+01
hour_dropoff_datetime          2.008258e+01
minute_dropoff_datetime        4.704865e+01
second_dropoff_datetime        4.682958e+01
horizon_dropoff_datetime       1.027407e+07
passenger_count                1