In [None]:
import datetime
import json
import time
from pathlib import Path

import numpy as np
import pandas as pd
from sdv.metadata import Metadata
from sdv.single_table import CTGANSynthesizer

# Resolve project paths from the notebook's working directory.
# Inside a notebook `__name__` is "__main__", so building paths from it only
# worked by accident (Path("__main__").resolve().parent is just the cwd).
# Anchor on the working directory explicitly instead; the result is the same.
# NOTE(review): assumes the notebook is launched from a directory two levels
# below the project root — confirm against the repo layout.
PROJECT_ROOT = Path.cwd().resolve().parent.parent
INPUT_FOLDER = PROJECT_ROOT / "data/input"
OUTPUT_FOLDER = PROJECT_ROOT / "data/output"
OUTPUT_FOLDER.mkdir(parents=True, exist_ok=True)

In [None]:
# UCI wine quality data: load the red and white datasets, tag each row with
# its wine type, and stack them into a single table for the GAN.
ifolder = INPUT_FOLDER / "UCI_winequality"
ofolder = OUTPUT_FOLDER / "UCI_winequality"
ofolder.mkdir(parents=True, exist_ok=True)

# The UCI CSVs are semicolon-delimited.
data_path = ifolder / "winequality-red.csv"
df1 = pd.read_csv(data_path, delimiter=";")
df1["winetype"] = "red"
data_path = ifolder / "winequality-white.csv"
df2 = pd.read_csv(data_path, delimiter=";")
# Label the white frame. Writing this to df1 would relabel every red wine as
# white and leave the white rows with a missing winetype.
df2["winetype"] = "white"

# ignore_index gives the combined frame a unique 0..n-1 index instead of two
# overlapping ranges from the source files.
df = pd.concat([df1, df2], ignore_index=True)

# Sanity checks: every row is labelled and the labels match their source.
assert df["winetype"].notna().all(), "unlabelled rows after concat"
assert (df["winetype"] == "red").sum() == len(df1)
assert (df["winetype"] == "white").sum() == len(df2)
print(df.head())
print(df.shape)

# Infer column types from the data and configure the CTGAN synthesizer.
df_meta = Metadata.detect_from_dataframe(df)
gen = CTGANSynthesizer(
    metadata=df_meta,
    epochs=500,
    verbose=True,
)

In [None]:
# Train the GAN - keep track of the time to execute
# Check if GAN exists - these take a while to fit, so only refit if necessary
duration = None
pkl = ofolder / "ctgan.pkl"
pkl_exists = Path(pkl).is_file()
if pkl_exists:
    gen.load(pkl)
else:
    # Fit a new model if it doesn't exist
    tstart = time.time()
    gen.fit(df)
    tend = time.time()
    duration = tend - tstart

In [None]:
# Persist the artifacts needed for later evaluation (fitted model, metadata,
# real data) and report how long training took.
if not pkl_exists:
    gen.save(pkl)
md = ofolder / "ctgan_metadata.json"
# Only write the metadata once; SDV's save_to_json errors on an existing file.
if not md.is_file():
    df_meta.save_to_json(md)
df.to_pickle(ofolder / "real_df.pkl")

# `duration` is None when the model was loaded instead of fitted.
if duration is None:
    print(f"Loaded existing model from {pkl}; no fit time to report.")
else:
    # Units are seconds, so display minutes.
    print(f"Time to fit: {(duration / 60):.2f} min.")

In [None]:
# Repeat the fit over a grid of epochs x batch sizes and save each model to
# evaluate later. Like the main fit, grid points that already have a saved
# model are skipped, so a re-run only fits what is missing.
epochs = [1, 5, 10, 25, 50, 100, 250, 500, 1000]
# CTGAN requires batch_size to be a multiple of its `pac` setting (default 10).
batches = [20, 50, 100, 500, 1000]

# Fail fast, before any slow fitting, if a batch size violates the constraint.
invalid_batches = [b for b in batches if b % 10 != 0]
if invalid_batches:
    raise ValueError(
        f"batch sizes must be multiples of 10 (CTGAN pac); got {invalid_batches}"
    )

for epoch in epochs:
    for batch in batches:
        grid_saveto = ofolder / f"grid/epoch {epoch} batch {batch}"
        grid_pkl = grid_saveto / "ctgan.pkl"
        if grid_pkl.is_file():
            print(f"Skipping epoch {epoch} batch {batch}: {grid_pkl} exists.")
            continue
        print(f"Currently fitting: epoch {epoch} batch {batch}.")
        grid_saveto.mkdir(parents=True, exist_ok=True)
        grid_gen = CTGANSynthesizer(
            metadata=df_meta,
            epochs=epoch,
            batch_size=batch,
            verbose=True,
        )
        tstart = time.perf_counter()
        grid_gen.fit(df)
        # Separate name so the main model's `duration` is not clobbered if
        # the reporting cell is re-run afterwards.
        grid_duration = time.perf_counter() - tstart
        grid_gen.save(grid_pkl)
        # Units are seconds, so display minutes.
        print(f"Time to fit: {(grid_duration / 60):.2f} min.")