In [7]:
import datetime
import json
import time
from pathlib import Path

import numpy as np
import pandas as pd
from sdv.metadata import Metadata
from sdv.single_table import CTGANSynthesizer

PROJECT_ROOT = Path(__name__).resolve().parent.parent.parent
INPUT_FOLDER = PROJECT_ROOT / "data/input"
OUTPUT_FOLDER = PROJECT_ROOT / "data/output"
OUTPUT_FOLDER.mkdir(parents=True, exist_ok=True)

In [8]:
# UCI Adult data
# Read data and check
ifolder = INPUT_FOLDER / "UCI_adult"
ofolder = OUTPUT_FOLDER / "UCI_adult"
ofolder.mkdir(parents=True, exist_ok=True)
data_path = ifolder / "adult.data"
colnames = [
    "age",
    "workclass",
    "fnlwgt",
    "education",
    "education_num",
    "marital_status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "capital_gain",
    "capital_loss",
    "minutes_per_week",
    "native_country",
    "Income_Category",
]
df = pd.read_csv(data_path, names=colnames)
print(df.head())
print(df.shape)

# Split into modeling and validation
np.random.seed(100)
df["rand10"] = np.random.randint(1, 10, len(df))
df_model = df[df["rand10"] <= 8]
df_valid = df[df["rand10"] >= 9]

# Set up metadata for GAN
df_meta = Metadata.detect_from_dataframe(df)
gen = CTGANSynthesizer(
    metadata=df_meta,
    epochs=500,
    verbose=True,
)

   age          workclass  fnlwgt   education  education_num  \
0   39          State-gov   77516   Bachelors             13   
1   50   Self-emp-not-inc   83311   Bachelors             13   
2   38            Private  215646     HS-grad              9   
3   53            Private  234721        11th              7   
4   28            Private  338409   Bachelors             13   

        marital_status          occupation    relationship    race      sex  \
0        Never-married        Adm-clerical   Not-in-family   White     Male   
1   Married-civ-spouse     Exec-managerial         Husband   White     Male   
2             Divorced   Handlers-cleaners   Not-in-family   White     Male   
3   Married-civ-spouse   Handlers-cleaners         Husband   Black     Male   
4   Married-civ-spouse      Prof-specialty            Wife   Black   Female   

   capital_gain  capital_loss  minutes_per_week  native_country  \
0          2174             0                40   United-States   
1     


We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.



In [None]:
# Train the GAN - keep track of the time to execute
# Check if GAN exists - these take a while to fit, so only refit if necessary
duration = None
pkl_exists = Path(ofolder / "ctgan.pkl").is_file()
if pkl_exists:
    gen.load(ofolder / "ctgan.pkl")
else:
    # Fit a new model if it doesn't exist
    tstart = time.time()
    gen.fit(df_model)
    tend = time.time()
    duration = tend - tstart

In [None]:
# Save the results, plot the loss function, and print the time to train the GAN
if not pkl_exists:
    gen.save(ofolder / "ctgan.pkl")
df_meta.save_to_json(ofolder / "ctgan_metadata.json")
df.to_pickle(ofolder / "real_df.pkl")

# Units are seconds, so display minutes
print(f"Time to fit: {(duration / 60):.2f} min.")

TypeError: unsupported operand type(s) for /: 'NoneType' and 'int'