In [2]:
import datetime
import json
import time
from pathlib import Path

import pandas as pd
from sdv.metadata import Metadata
from sdv.single_table import CTGANSynthesizer

PROJECT_ROOT = Path(__name__).resolve().parent.parent
INPUT_FOLDER = PROJECT_ROOT / "data/input"
OUTPUT_FOLDER = PROJECT_ROOT / "data/output"
OUTPUT_FOLDER.mkdir(parents=True, exist_ok=True)

In [None]:
# UCI Adult data
# Read data and check
ifolder = INPUT_FOLDER / "UCI_adult"
ofolder = OUTPUT_FOLDER / "UCI_adult"
ofolder.mkdir(parents=True, exist_ok=True)
data_path = ifolder / "adult.data"
colnames = [
    "age",
    "workclass",
    "fnlwgt",
    "education",
    "education_num",
    "marital_status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "capital_gain",
    "capital_loss",
    "minutes_per_week",
    "native_country",
    "Income_Category",
]
df = pd.read_csv(data_path, names=colnames)
print(df.head())
print(df.shape)

# set up metadata for GAN
df_meta = Metadata.detect_from_dataframe(df)
gen = CTGANSynthesizer(
    metadata=df_meta,
    epochs=500,
    verbose=True,
)


   age          workclass  fnlwgt   education  education_num  \
0   39          State-gov   77516   Bachelors             13   
1   50   Self-emp-not-inc   83311   Bachelors             13   
2   38            Private  215646     HS-grad              9   
3   53            Private  234721        11th              7   
4   28            Private  338409   Bachelors             13   

        marital_status          occupation    relationship    race      sex  \
0        Never-married        Adm-clerical   Not-in-family   White     Male   
1   Married-civ-spouse     Exec-managerial         Husband   White     Male   
2             Divorced   Handlers-cleaners   Not-in-family   White     Male   
3   Married-civ-spouse   Handlers-cleaners         Husband   Black     Male   
4   Married-civ-spouse      Prof-specialty            Wife   Black   Female   

   capital_gain  capital_loss  minutes_per_week  native_country  \
0          2174             0                40   United-States   
1     


We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.



In [4]:
# Train the GAN - keep track of the time to execute
tstart = time.time()
gen.fit(df)
tend = time.time()
duration = tend - tstart


CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)

Gen. (-0.65) | Discrim. (-0.08): 100%|██████████| 500/500 [26:02<00:00,  3.13s/it]


In [None]:
# save the results, plot the loss function, and print the time to train the GAN
gen.save(ofolder / "ctgan.pkl")
df_meta.save_to_json(ofolder / "ctgan_metdata.json")
df.to_pickle(ofolder / "real_df.pkl")
# units are seconds, so display minutes
print(f"Time to fit: {(duration / 60):.2f} min.")

Time to fit: 26.56 min.


In [None]:
# UCI wine quality data
# Read data and check
ifolder = INPUT_FOLDER / "UCI_winequality"
ofolder = OUTPUT_FOLDER / "UCI_winequality"
ofolder.mkdir(parents=True, exist_ok=True)
data_path = ifolder / "winequality-red.csv"
df1 = pd.read_csv(data_path, delimiter=";")
df1["winetype"] = "red"
data_path = ifolder / "winequality-white.csv"
df2 = pd.read_csv(data_path, delimiter=";")
df1["winetype"] = "white"
df = pd.concat([df1, df2])
print(df.head())
print(df.shape)

# set up metadata for GAN
df_meta = Metadata.detect_from_dataframe(df)
gen = CTGANSynthesizer(
    metadata=df_meta,
    epochs=500,
    verbose=True,
)

   fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \
0            7.4              0.70         0.00             1.9      0.076   
1            7.8              0.88         0.00             2.6      0.098   
2            7.8              0.76         0.04             2.3      0.092   
3           11.2              0.28         0.56             1.9      0.075   
4            7.4              0.70         0.00             1.9      0.076   

   free sulfur dioxide  total sulfur dioxide  density    pH  sulphates  \
0                 11.0                  34.0   0.9978  3.51       0.56   
1                 25.0                  67.0   0.9968  3.20       0.68   
2                 15.0                  54.0   0.9970  3.26       0.65   
3                 17.0                  60.0   0.9980  3.16       0.58   
4                 11.0                  34.0   0.9978  3.51       0.56   

   alcohol  quality  
0      9.4        5  
1      9.8        5  
2      9.8        5 


We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.



In [7]:
# Train the GAN - keep track of the time to execute
tstart = time.time()
gen.fit(df)
tend = time.time()
duration = tend - tstart

Gen. (-2.65) | Discrim. (-0.09): 100%|██████████| 500/500 [03:27<00:00,  2.41it/s]


In [None]:
# save the results, plot the loss function, and print the time to train the GAN
gen.save(ofolder / "ctgan.pkl")
df_meta.save_to_json(ofolder / "ctgan_metdata.json")
df.to_pickle(ofolder / "real_df.pkl")
# units are seconds, so display minutes
print(f"Time to fit: {(duration / 60):.2f} min.")

Time to fit: 3.60 min.


In [None]:
# do again, but train for different epochs/batches and save to evaluate later
epochs = [1, 5, 10, 25, 50, 100, 250, 500, 1000]
batches = [20, 50, 100, 500, 1000] # need to be in increments of 10
for epoch in epochs:
    for batch in batches:
        print(f"Currently fitting: epoch {epoch} batch {batch}.")
        grid_saveto = ofolder / f"grid/epoch {epoch} batch {batch}"
        grid_saveto.mkdir(parents=True, exist_ok=True)
        grid_gen = CTGANSynthesizer(
            metadata=df_meta,
            epochs=epoch,
            batch_size=batch,
            verbose=True,
        )
        tstart = time.time()
        grid_gen.fit(df)
        tend = time.time()
        duration = tend - tstart
        grid_gen.save(grid_saveto / "ctgan.pkl")
        # units are seconds, so display minutes
        print(f"Time to fit: {(duration / 60):.2f} min.")

Currently fitting: epoch 1 batch 20.


Gen. (1.23) | Discrim. (0.39): 100%|██████████| 1/1 [00:05<00:00,  5.22s/it]


Time to fit: 0.23 min.
Currently fitting: epoch 1 batch 50.


Gen. (1.62) | Discrim. (-0.06): 100%|██████████| 1/1 [00:02<00:00,  2.25s/it]


Time to fit: 0.11 min.
Currently fitting: epoch 1 batch 100.


Gen. (0.94) | Discrim. (0.22): 100%|██████████| 1/1 [00:01<00:00,  1.28s/it]


Time to fit: 0.09 min.
Currently fitting: epoch 1 batch 500.


Gen. (1.93) | Discrim. (-0.03): 100%|██████████| 1/1 [00:00<00:00,  2.25it/s]


Time to fit: 0.08 min.
Currently fitting: epoch 1 batch 1000.


Gen. (2.03) | Discrim. (0.00): 100%|██████████| 1/1 [00:00<00:00,  2.96it/s]


Time to fit: 0.08 min.
Currently fitting: epoch 5 batch 20.


Gen. (-1.02) | Discrim. (0.18): 100%|██████████| 5/5 [00:25<00:00,  5.17s/it]


Time to fit: 0.50 min.
Currently fitting: epoch 5 batch 50.


Gen. (-1.23) | Discrim. (-0.10): 100%|██████████| 5/5 [00:11<00:00,  2.22s/it]


Time to fit: 0.26 min.
Currently fitting: epoch 5 batch 100.


Gen. (-1.02) | Discrim. (-0.18): 100%|██████████| 5/5 [00:06<00:00,  1.26s/it]


Time to fit: 0.18 min.
Currently fitting: epoch 5 batch 500.


Gen. (0.95) | Discrim. (0.02): 100%|██████████| 5/5 [00:02<00:00,  2.35it/s] 


Time to fit: 0.11 min.
Currently fitting: epoch 5 batch 1000.


Gen. (1.09) | Discrim. (0.12): 100%|██████████| 5/5 [00:01<00:00,  3.06it/s] 


Time to fit: 0.10 min.
Currently fitting: epoch 10 batch 20.


Gen. (-1.00) | Discrim. (-0.21): 100%|██████████| 10/10 [00:51<00:00,  5.18s/it]


Time to fit: 0.94 min.
Currently fitting: epoch 10 batch 50.


Gen. (-1.90) | Discrim. (-0.27): 100%|██████████| 10/10 [00:23<00:00,  2.31s/it]


Time to fit: 0.46 min.
Currently fitting: epoch 10 batch 100.


Gen. (-1.70) | Discrim. (0.13): 100%|██████████| 10/10 [00:12<00:00,  1.29s/it]


Time to fit: 0.29 min.
Currently fitting: epoch 10 batch 500.


Gen. (0.35) | Discrim. (0.03): 100%|██████████| 10/10 [00:04<00:00,  2.29it/s]


Time to fit: 0.15 min.
Currently fitting: epoch 10 batch 1000.


Gen. (0.97) | Discrim. (-0.07): 100%|██████████| 10/10 [00:03<00:00,  2.99it/s]


Time to fit: 0.13 min.
Currently fitting: epoch 25 batch 20.


Gen. (-0.90) | Discrim. (-2.62): 100%|██████████| 25/25 [02:06<00:00,  5.04s/it]


Time to fit: 2.17 min.
Currently fitting: epoch 25 batch 50.


Gen. (-2.27) | Discrim. (-0.37): 100%|██████████| 25/25 [00:55<00:00,  2.22s/it]


Time to fit: 1.00 min.
Currently fitting: epoch 25 batch 100.


Gen. (-2.77) | Discrim. (0.27): 100%|██████████| 25/25 [00:31<00:00,  1.24s/it] 


Time to fit: 0.59 min.
Currently fitting: epoch 25 batch 500.


Gen. (-1.31) | Discrim. (-0.15): 100%|██████████| 25/25 [00:10<00:00,  2.29it/s]


Time to fit: 0.25 min.
Currently fitting: epoch 25 batch 1000.


Gen. (-0.39) | Discrim. (0.02): 100%|██████████| 25/25 [00:08<00:00,  3.10it/s] 


Time to fit: 0.21 min.
Currently fitting: epoch 50 batch 20.


Gen. (0.05) | Discrim. (0.10): 100%|██████████| 50/50 [04:09<00:00,  4.99s/it]  


Time to fit: 4.23 min.
Currently fitting: epoch 50 batch 50.


Gen. (-1.39) | Discrim. (0.16): 100%|██████████| 50/50 [01:50<00:00,  2.22s/it] 


Time to fit: 1.92 min.
Currently fitting: epoch 50 batch 100.


Gen. (-3.05) | Discrim. (0.09): 100%|██████████| 50/50 [01:06<00:00,  1.33s/it] 


Time to fit: 1.18 min.
Currently fitting: epoch 50 batch 500.


Gen. (-1.87) | Discrim. (-0.08): 100%|██████████| 50/50 [00:22<00:00,  2.24it/s]


Time to fit: 0.44 min.
Currently fitting: epoch 50 batch 1000.


Gen. (-1.49) | Discrim. (0.01): 100%|██████████| 50/50 [00:17<00:00,  2.81it/s] 


Time to fit: 0.37 min.
Currently fitting: epoch 100 batch 20.


Gen. (-1.15) | Discrim. (-0.72): 100%|██████████| 100/100 [09:01<00:00,  5.42s/it]


Time to fit: 9.10 min.
Currently fitting: epoch 100 batch 50.


Gen. (-1.25) | Discrim. (0.90): 100%|██████████| 100/100 [04:04<00:00,  2.45s/it]


Time to fit: 4.23 min.
Currently fitting: epoch 100 batch 100.


Gen. (-3.30) | Discrim. (-0.07): 100%|██████████| 100/100 [02:07<00:00,  1.28s/it]


Time to fit: 2.20 min.
Currently fitting: epoch 100 batch 500.


Gen. (-2.76) | Discrim. (0.18): 100%|██████████| 100/100 [00:45<00:00,  2.21it/s]


Time to fit: 0.82 min.
Currently fitting: epoch 100 batch 1000.


Gen. (-2.11) | Discrim. (0.02): 100%|██████████| 100/100 [00:34<00:00,  2.91it/s]


Time to fit: 0.65 min.
Currently fitting: epoch 250 batch 20.


Gen. (0.04) | Discrim. (-0.33): 100%|██████████| 250/250 [22:32<00:00,  5.41s/it] 


Time to fit: 22.62 min.
Currently fitting: epoch 250 batch 50.


Gen. (-0.11) | Discrim. (-0.07): 100%|██████████| 250/250 [09:48<00:00,  2.35s/it]


Time to fit: 9.95 min.
Currently fitting: epoch 250 batch 100.


Gen. (-3.29) | Discrim. (-0.22): 100%|██████████| 250/250 [05:27<00:00,  1.31s/it]


Time to fit: 5.61 min.
Currently fitting: epoch 250 batch 500.


Gen. (-2.98) | Discrim. (0.05): 100%|██████████| 250/250 [01:53<00:00,  2.20it/s] 


Time to fit: 2.04 min.
Currently fitting: epoch 250 batch 1000.


Gen. (-2.40) | Discrim. (-0.07): 100%|██████████| 250/250 [01:26<00:00,  2.88it/s]


Time to fit: 1.52 min.
Currently fitting: epoch 500 batch 20.


Gen. (-0.83) | Discrim. (0.83): 100%|██████████| 500/500 [46:24<00:00,  5.57s/it] 


Time to fit: 46.50 min.
Currently fitting: epoch 500 batch 50.


Gen. (-1.05) | Discrim. (-0.48): 100%|██████████| 500/500 [19:39<00:00,  2.36s/it]


Time to fit: 19.81 min.
Currently fitting: epoch 500 batch 100.


Gen. (-3.00) | Discrim. (-0.23): 100%|██████████| 500/500 [10:56<00:00,  1.31s/it]


Time to fit: 11.10 min.
Currently fitting: epoch 500 batch 500.


Gen. (-2.93) | Discrim. (0.01): 100%|██████████| 500/500 [03:43<00:00,  2.24it/s] 


Time to fit: 3.87 min.
Currently fitting: epoch 500 batch 1000.


Gen. (-2.03) | Discrim. (-0.02): 100%|██████████| 500/500 [02:56<00:00,  2.83it/s]


Time to fit: 3.01 min.
Currently fitting: epoch 1000 batch 20.


Gen. (0.44) | Discrim. (-1.28): 100%|██████████| 1000/1000 [1:27:29<00:00,  5.25s/it]


Time to fit: 87.57 min.
Currently fitting: epoch 1000 batch 50.


Gen. (-0.25) | Discrim. (-0.44): 100%|██████████| 1000/1000 [38:28<00:00,  2.31s/it]


Time to fit: 38.61 min.
Currently fitting: epoch 1000 batch 100.


Gen. (-0.86) | Discrim. (-0.01): 100%|██████████| 1000/1000 [20:07<00:00,  1.21s/it]


Time to fit: 20.26 min.
Currently fitting: epoch 1000 batch 500.


Gen. (-3.17) | Discrim. (-0.01): 100%|██████████| 1000/1000 [07:16<00:00,  2.29it/s]


Time to fit: 7.42 min.
Currently fitting: epoch 1000 batch 1000.


Gen. (-0.88) | Discrim. (-0.22): 100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]

Time to fit: 5.51 min.





In [None]:
# Kaggle credit card fraud data
# Read data and check
ifolder = INPUT_FOLDER / "Kaggle_creditcardfraud"
ofolder = OUTPUT_FOLDER / "Kaggle_creditcardfraud"
ofolder.mkdir(parents=True, exist_ok=True)
data_path = INPUT_FOLDER / "creditcard.csv"
df = pd.read_csv(data_path)
print(df.head())
print(df.shape)

# set up metadata for GAN
df_meta = Metadata.detect_from_dataframe(df)
gen = CTGANSynthesizer(
    metadata=df_meta,
    epochs=500,
    verbose=True,
)

   Time        V1        V2        V3        V4        V5        V6        V7  \
0   0.0 -1.359807 -0.072781  2.536347  1.378155 -0.338321  0.462388  0.239599   
1   0.0  1.191857  0.266151  0.166480  0.448154  0.060018 -0.082361 -0.078803   
2   1.0 -1.358354 -1.340163  1.773209  0.379780 -0.503198  1.800499  0.791461   
3   1.0 -0.966272 -0.185226  1.792993 -0.863291 -0.010309  1.247203  0.237609   
4   2.0 -1.158233  0.877737  1.548718  0.403034 -0.407193  0.095921  0.592941   

         V8        V9  ...       V21       V22       V23       V24       V25  \
0  0.098698  0.363787  ... -0.018307  0.277838 -0.110474  0.066928  0.128539   
1  0.085102 -0.255425  ... -0.225775 -0.638672  0.101288 -0.339846  0.167170   
2  0.247676 -1.514654  ...  0.247998  0.771679  0.909412 -0.689281 -0.327642   
3  0.377436 -1.387024  ... -0.108300  0.005274 -0.190321 -1.175575  0.647376   
4 -0.270533  0.817739  ... -0.009431  0.798278 -0.137458  0.141267 -0.206010   

        V26       V27       V28 


We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.



In [14]:
# Train the GAN - keep track of the time to execute
tstart = time.time()
gen.fit(df)
tend = time.time()
duration = tend - tstart

Gen. (0.30) | Discrim. (-0.22): 100%|██████████| 500/500 [4:36:08<00:00, 33.14s/it]   


In [None]:
# save the results, plot the loss function, and print the time to train the GAN
gen.save(ofolder / "ctgan.pkl")
df_meta.save_to_json(ofolder / "ctgan_metdata.json")
df.to_pickle(ofolder / "real_df.pkl")
# units are seconds, so display minutes
print(f"Time to fit: {(duration / 60):.2f} min.")

Time to fit: 288.82 min.


In [None]:
# FEMA disasters
# Read data and check
# Requires a little more data transformation
ifolder = INPUT_FOLDER
ofolder = OUTPUT_FOLDER / "FEMA_DisasterDeclarationsSummaries"
ofolder.mkdir(parents=True, exist_ok=True)
data_path = ifolder / "FEMA_DisasterDeclarationsSummaries.json"
temp = None
with open(data_path, "r") as io:
    temp = json.load(io)
# has a lot of columns not needed
orig_df = pd.DataFrame(temp.get("DisasterDeclarationsSummaries"))

# read about the columns here: https://www.fema.gov/openfema-data-page/disaster-declarations-summaries-v2
df = orig_df[
    [
        "state",
        "fipsCountyCode",
        "incidentType",
        "ihProgramDeclared",
        "iaProgramDeclared",
        "paProgramDeclared",
        "hmProgramDeclared",
        "incidentBeginDate",
        "incidentEndDate",
    ]
]
print(df.head())
print(df.shape)

  state fipsCountyCode incidentType  ihProgramDeclared  iaProgramDeclared  \
0    OR            067         Fire              False              False   
1    OR            031         Fire              False              False   
2    OR            017         Fire              False              False   
3    WA            077         Fire              False              False   
4    ID            000         Fire              False              False   

   paProgramDeclared  hmProgramDeclared         incidentBeginDate  \
0               True               True  2024-08-08T00:00:00.000Z   
1               True               True  2024-08-04T00:00:00.000Z   
2               True               True  2024-08-02T00:00:00.000Z   
3               True               True  2024-07-23T00:00:00.000Z   
4               True               True  2024-07-25T00:00:00.000Z   

  incidentEndDate  
0            None  
1            None  
2            None  
3            None  
4            None  
(6

In [32]:
# do various transformations
def f(row: pd.Series) -> str:
    x = row["incidentBeginDate"]
    try:
        x = datetime.datetime.strptime(x[:10], "%Y-%m-%d")
        return x.strftime("%Y")
    except Exception as ex:
        print(str(ex))
        return None


df["incidentYear"] = df.apply(f, axis=1)


def f(row: pd.Series) -> str:
    x = row["incidentBeginDate"]
    try:
        x = datetime.datetime.strptime(x[:10], "%Y-%m-%d")
        return x.strftime("%m")
    except Exception as ex:
        print(str(ex))
        return None


df["incidentMonth"] = df.apply(f, axis=1)


def f(row: pd.Series) -> str:
    x = row["incidentBeginDate"]
    y = row["incidentEndDate"]
    if y is None:
        return None
    try:
        x = datetime.datetime.strptime(x[:10], "%Y-%m-%d")
        y = datetime.datetime.strptime(y[:10], "%Y-%m-%d")
        d = y - x
        return d.days
    except Exception as ex:
        print(str(ex))
        return None


df["incidentDurationDays"] = df.apply(f, axis=1)

df.drop(["incidentBeginDate"], axis=1, inplace=True)
df.drop(["incidentEndDate"], axis=1, inplace=True)
df = df.replace({float("nan"): None})
df.head()
df.describe()

Unnamed: 0,state,fipsCountyCode,incidentType,ihProgramDeclared,iaProgramDeclared,paProgramDeclared,hmProgramDeclared,incidentYear,incidentMonth,incidentDurationDays
count,67375,67375,67375,67375,67375,67375,67375,67375,67375,66873.0
unique,59,347,26,2,2,2,2,73,12,152.0
top,TX,0,Severe Storm,False,False,True,False,2020,1,0.0
freq,5350,1544,18402,55933,50188,62927,37502,9706,12945,9356.0


In [33]:
# set up metadata for GAN
df_meta = Metadata.detect_from_dataframe(df)
gen = CTGANSynthesizer(
    metadata=df_meta,
    epochs=500,
    verbose=True,
)


We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.



In [34]:
# Train the GAN - keep track of the time to execute
tstart = time.time()
gen.fit(df)
tend = time.time()
duration = tend - tstart

Gen. (-0.31) | Discrim. (0.36):  56%|█████▌    | 278/500 [1:00:59<48:42, 13.16s/it] 


KeyboardInterrupt: 

In [None]:
# save the results, plot the loss function, and print the time to train the GAN
gen.save(ofolder / "ctgan.pkl")
df_meta.save_to_json(ofolder / "ctgan_metdata.json")
df.to_pickle(ofolder / "real_df.pkl")
orig_df.to_pickle(ofolder / "orig_df.pkl")
# units are seconds, so display minutes
print(f"Time to fit: {(duration / 60):.2f} min.")


You are saving a synthesizer that has not yet been fitted. You will not be able to sample synthetic data without fitting. We recommend fitting the synthesizer first and then saving.



Time to fit: 288.82 min.
