# GAN Model

The purpose of this notebook is to attempt to create, and then manipulate, platoon data. The method to accomplish this is tentatively set to be a Generative Adversarial Network (GAN).

In [1]:
import pandas as pd

# Preprocessing

This section takes the platoon data and creates features of equal length.

In [None]:
df = pd.read_pickle('CeneriData/2003_2019_platoon_h2.6_lane4.pkl')

In [None]:
df[df.Length > 1].Length.hist(bins=10)

In [None]:
df = df.drop(columns='Lane')

In [None]:
df.Length.value_counts()

In [None]:
df = df.drop(columns='Platoon')

In [None]:
df[df.CLASS == 1].CLASS.value_counts()

In [None]:
#df = df.drop(columns=['AX_W', 'AX_DIST']) #After expanding the features to be the same length, keeping these columns is too large

In [None]:
df.columns

In [None]:
df = df[df.Length < 10]

In [None]:
expand_list = ['CLASS', 'GW_TOT', 'LENTH', 'IVT', 'SPEED', 'AX']

In [None]:
dfs = []

In [None]:
df = df.reset_index(drop=True)

In [None]:
# Turn every feature named in `expand_list` into its own nine-column frame
# ('<feature>_1' ... '<feature>_9') and collect those frames in `dfs`.
# The per-row values are handed to the DataFrame constructor as a list of rows
# (assumes each cell holds a sequence of at most nine entries -- confirm);
# the column assignment requires the resulting frame to be exactly nine wide.
for feature in expand_list:
    expanded = pd.DataFrame(df[feature].values.tolist())
    expanded.columns = ['{}_{}'.format(feature, pos) for pos in range(1, 10)]
    dfs.append(expanded)

In [None]:
ax_list = ['AX_W', 'AX_DIST']

In [None]:
df_small = df.drop(columns=expand_list)

In [None]:
df_small = df_small.drop(columns=ax_list)

In [None]:
# dfs[0] is the first frame appended by the expansion loop (expand_list[0] is 'CLASS'
# when the cells above are run once, in order). Every entry equal to 0 in that frame
# is overwritten in place with 99, which acts as a separate category value.
dfs[0][dfs[0] == 0] = 99 #Replaces the zero with a 99 category

In [None]:
dfs[0]

In [None]:
dfs.append(df_small)

In [None]:
df_cat = pd.concat(dfs, axis= 1)

In [None]:
df_cat = df_cat.fillna(0)

In [None]:
df_ax =[]

In [None]:
# Same expansion as for `expand_list`, applied to the axle features in `ax_list`:
# each feature becomes a nine-column frame ('<feature>_1' ... '<feature>_9')
# that is appended to `df_ax`. The column assignment requires exactly nine columns.
for feature in ax_list:
    expanded = pd.DataFrame(df[feature].values.tolist())
    expanded.columns = ['{}_{}'.format(feature, pos) for pos in range(1, 10)]
    df_ax.append(expanded)

In [None]:
zero_list = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [None]:
msk = df_ax[0].isna()

In [None]:
# Keep every non-missing cell of df_ax[0] (built from ax_list[0], 'AX_W'); each missing
# cell is taken from `other`, a Series holding one ten-zero list per row of `df`,
# aligned along the row axis.
# NOTE(review): presumably the real entries are 10-element lists, so missing vehicles
# get a same-shaped all-zero placeholder -- confirm against the raw data.
df_ax[0] = df_ax[0].where(~msk, other=pd.Series([zero_list]*df.shape[0]), axis=0)

In [None]:
# Boolean frame marking the missing cells of df_ax[1] (built from ax_list[1], 'AX_DIST').
msk = df_ax[1].isna()

# Keep the non-missing cells; each missing cell is taken from a Series holding one
# nine-zero list (zero_list without its last element) per row of `df`, aligned by row.
# NOTE(review): presumably axle-distance entries are one element shorter than the
# axle-weight entries, hence the [:-1] -- confirm.
df_ax[1] = df_ax[1].where(~msk, other=pd.Series([zero_list[:-1]]*df.shape[0]), axis=0)

In [None]:
df_ax = pd.concat(df_ax, axis= 1)

In [None]:
df_cat = pd.concat([df_ax, df_cat], axis= 1)

In [None]:
df_cat.AX_DIST_2

Finally, the Start and End variables will be replaced with the day of the week and the hour of the day, both derived from Start.

In [None]:
df_cat['Weekday'] = df_cat.Start.dt.dayofweek

In [None]:
df_cat['Hour'] = df_cat.Start.dt.hour

In [None]:
df_cat = df_cat.drop(columns=['Start', 'End'])

In [None]:
df_cat.isna().sum()

In [None]:
pd.to_pickle(df_cat,'CeneriData/cleaned_2003_2019_platoon.pkl')

# Not....

This section will be for auto-encoding the non-continuous variables such as the Class, weekday and hour variables.

In [None]:
df_cat

# CTGAN

This section will test to see if the CTGAN can be applied to our data.

In [2]:
df = pd.read_pickle('CeneriData/cleaned_2003_2019_platoon.pkl')

In [None]:
# Names of the expanded axle columns: for every feature in `ax_list`, the nine
# positional columns '<feature>_1' ... '<feature>_9', feature by feature.
old_ax = ['{}_{}'.format(ax, pos) for ax in ax_list for pos in range(1, 10)]

In [None]:
old_ax

In [None]:
df_cat = df_cat.drop(columns=old_ax)

In [None]:
pd.to_pickle(df_cat,'CeneriData/cleaned_2003_2019_platoon_fullax.zip')

In [None]:
df = pd.read_pickle('CeneriData/cleaned_2003_2019_platoon_fullax.zip')

In [None]:
# Columns to be treated as categorical by the synthesizer: the nine vehicle-class
# columns, the platoon length, the two calendar features and the nine 'AX_<n>' columns.
discrete_columns = (
    ['CLASS_{}'.format(pos) for pos in range(1, 10)]
    + ['Length', 'Weekday', 'Hour']
    + ['AX_{}'.format(pos) for pos in range(1, 10)]
)

In [3]:
from ctgan import CTGANSynthesizer



In [None]:
ctgan = CTGANSynthesizer()

In [4]:
# Start the list of axle columns with the nine axle-weight columns 'AX_W_1' ... 'AX_W_9'.
ax_cols = ['{}_{}'.format('AX_W', pos) for pos in range(1, 10)]

In [5]:
# Append the nine axle-distance columns 'AX_DIST_1' ... 'AX_DIST_9' to `ax_cols`.
ax_cols.extend('{}_{}'.format('AX_DIST', pos) for pos in range(1, 10))

In [6]:
# Remove the expanded axle-weight / axle-distance columns from `df`.
# errors='ignore' keeps this cell working when `df` was loaded from the
# '..._fullax.zip' pickle, which is written after those columns have already been
# dropped; without it a top-to-bottom run fails here with a KeyError.
df_noax = df.drop(columns=ax_cols, errors='ignore')

In [7]:
# Columns of df_noax whose name ends in a digit greater than 5; only the final
# character of the name is inspected.
no_col = [name for name in df_noax.columns if name[-1].isdigit() and int(name[-1]) > 5]

In [8]:
df_sm = df_noax.drop(columns=no_col)

In [9]:
df_sm = df_sm[df_sm.Length <= 5]

In [10]:
import random

In [11]:
df_sm = df_sm.astype(int)

In [None]:
# Draw, without replacement, 5% of the row *positions* 0 .. len(df_sm)-1.
# No random seed is set in the visible cells, so the selection differs between runs.
ind_red = random.sample(range(0, df_sm.shape[0]), round(df_sm.shape[0]*.05))

In [None]:
# `ind_red` holds row positions drawn from range(len(df_sm)). df_sm still carries the
# index labels it had before the `Length <= 5` filter (the index was never reset), so
# the rows must be selected by position; label-based .loc would pick the wrong rows
# or raise a KeyError for labels that were filtered out.
df_small = df_sm.iloc[ind_red]

In [None]:
df_small.columns

In [None]:
# Pass only the discrete columns that actually exist in the training frame:
# `discrete_columns` also names CLASS_6..9 / AX_6..9, which were pruned from the data
# before sampling. This mirrors the filtering done inside ctgan_50.
ctgan.fit(df_small, [col for col in discrete_columns if col in df_small.columns], epochs=20)

In [None]:
samples = ctgan.sample(1000)

In [None]:
ctgan.save('ctgan_epoch20.pkl')

In [None]:
samples

Below is an attempt to parallelize different sizes of the dataset being used with the CTGAN

In [17]:
def parallelize_ctgan(df=None, func=None, n_cores=20, lengths=None):
    """Run ``func`` once per platoon length, each call in its own worker process.

    For every value in ``lengths`` the rows of ``df`` whose ``Length`` column equals
    that value are selected, the columns that hold a single distinct value within
    that subset are dropped, and ``func(subset, length)`` is executed in a process
    pool created from the 'spawn' multiprocessing context.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing a ``Length`` column.
    func : callable
        Invoked as ``func(subset, length)`` in a worker; its return value is discarded.
    n_cores : int
        Upper bound on the number of worker processes.
    lengths : iterable, optional
        ``Length`` values to process. ``None`` or an empty iterable means there is
        nothing to do and no pool is started.

    Returns
    -------
    None
    """
    # Materialise once: `lengths` is walked twice below (split + zip), which would
    # silently yield no work for a one-shot iterator.
    lengths = list(lengths) if lengths is not None else []
    if not lengths:
        return

    df_split = []
    for length in lengths:
        subset = df[df.Length == length]
        # Columns with exactly one distinct value in this subset are dropped before fitting.
        distinct_counts = subset.nunique()
        constant_cols = distinct_counts[distinct_counts == 1].index
        df_split.append(subset.drop(constant_cols, axis=1))

    # Never start more workers than there are subsets to fit.
    n_workers = max(1, min(n_cores, len(df_split)))
    ctx = mp.get_context('spawn')
    pool = ctx.Pool(n_workers)
    try:
        pool.starmap(func, zip(df_split, lengths))
    finally:
        # Release the workers even when one of the fits raised.
        pool.close()
        pool.join()

In [13]:
def ctgan_50(df, length):
    """Fit a CTGAN synthesizer on ``df`` for 50 epochs and save it to disk.

    Parameters
    ----------
    df : pandas.DataFrame
        Training data for one platoon length.
    length : int
        Platoon length the data belongs to; used in the progress message and in
        the output file name ``'ctgan_length{length}_epoch50.pkl'``.

    Returns
    -------
    None
    """
    # Imported inside the function rather than at notebook level -- presumably so the
    # name is available inside a freshly spawned worker process; confirm.
    from ctgan import CTGANSynthesizer

    synthesizer = CTGANSynthesizer()
    discrete_columns = ['CLASS_1', 'CLASS_2','CLASS_3','CLASS_4','CLASS_5','CLASS_6','CLASS_7','CLASS_8', 'CLASS_9', 'Length',
                    'Weekday','Hour','AX_1','AX_2','AX_3','AX_4','AX_5','AX_6','AX_7','AX_8','AX_9']
    # Keep only the discrete columns present in this dataframe, in their listed order
    # (a set intersection would give a different order from run to run).
    present = set(df.columns)
    tmp_discrete_columns = [col for col in discrete_columns if col in present]
    print('Starting {} length fit'.format(length))
    synthesizer.fit(df, tmp_discrete_columns, epochs=50)
    synthesizer.save('ctgan_length{}_epoch50.pkl'.format(length))

In [14]:
import multiprocess as mp

In [None]:
# Fit one CTGAN per platoon length found in df_sm, each in its own worker process.
# The lengths are computed inline so this cell does not rely on a `lengths` variable
# that is only assigned in a later cell (which fails on a top-to-bottom run).
parallelize_ctgan(df_sm, ctgan_50, n_cores=10, lengths=list(df_sm.Length.unique()))

In [15]:
lengths = list(df_sm.Length.unique())