In [11]:
import pickle
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.io as pio
from sdv.evaluation.single_table import (
    evaluate_quality,
    get_column_plot,
    run_diagnostic,
)

# Use plotly's "vscode" renderer for every figure shown in this notebook.
pio.renderers.default = "vscode"

# The project root is two levels above the kernel's working directory.
# The earlier `Path(__name__).resolve().parent.parent.parent` reached the same
# place only by accident: `__name__` is the module name ("__main__" in a
# notebook), not a file path, so it resolved to "<cwd>/__main__" and the first
# `.parent` just stripped that fake component again. `Path.cwd()` states the
# intent directly and yields the identical directory.
PROJECT_ROOT = Path.cwd().resolve().parent.parent
INPUT_FOLDER = PROJECT_ROOT / "data/input"
OUTPUT_FOLDER = PROJECT_ROOT / "data/output"
# Make sure the output location exists before any cell tries to write to it.
OUTPUT_FOLDER.mkdir(parents=True, exist_ok=True)

In [None]:
# Read the pickled generator, load the real data and sample a synthetic table.
ifolder = INPUT_FOLDER / "UCI_adult"
ofolder = OUTPUT_FOLDER / "UCI_adult"

# SECURITY: unpickling executes arbitrary code stored in the file. Only load
# pickles that were produced by this project; never point this at untrusted files.
# The handle is deliberately not called `io`, to avoid shadowing the stdlib
# module name and being mistaken for the `pio` alias imported above.
with open(ofolder / "ctgan.pkl", "rb") as model_file:
    gen = pickle.load(model_file)

# Original (real) data; the same pickle caveat applies to `read_pickle`.
real_df = pd.read_pickle(ofolder / "real_df.pkl")
# Generate as many synthetic rows as there are real rows so that both tables
# have the same size in the comparisons below.
fake_df = gen.sample(num_rows=real_df.shape[0])

In [13]:
# Inspect the loss values recorded for the generator
loss_fig = gen.get_loss_values_plot()
loss_fig.show()

In [14]:
# Run the built-in diagnostic report comparing the real and synthetic tables
diag_metadata = gen.get_metadata()
gen_diagnostics = run_diagnostic(
    real_data=real_df,
    synthetic_data=fake_df,
    metadata=diag_metadata,
)

Generating report ...

(1/2) Evaluating Data Validity: |██████████| 16/16 [00:00<00:00, 292.35it/s]|
Data Validity Score: 100.0%

(2/2) Evaluating Data Structure: |██████████| 1/1 [00:00<00:00, 505.16it/s]|
Data Structure Score: 100.0%

Overall Score (Average): 100.0%



In [15]:
# Visualize the "Data Validity" property of the diagnostic report
validity_fig = gen_diagnostics.get_visualization(property_name="Data Validity")
validity_fig

In [16]:
# Run the built-in quality report comparing the real and synthetic tables
quality_metadata = gen.get_metadata()
gen_quality = evaluate_quality(
    real_data=real_df,
    synthetic_data=fake_df,
    metadata=quality_metadata,
)

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 16/16 [00:00<00:00, 100.78it/s]|
Column Shapes Score: 86.51%

(2/2) Evaluating Column Pair Trends: |██████████| 120/120 [00:01<00:00, 97.03it/s]|
Column Pair Trends Score: 82.78%

Overall Score (Average): 84.65%



In [17]:
# Show both quality visualizations.
# A notebook cell only renders its last bare expression, so listing the two
# calls one after another silently dropped the "Column Shapes" figure.
# Calling `.show()` on each figure renders both.
for quality_property in ("Column Shapes", "Column Pair Trends"):
    gen_quality.get_visualization(property_name=quality_property).show()

In [18]:
# The real and synthetic distributions of this column do not match well,
# which is why the correlation (pair-trend) score is so poor.
col = "Income_Category"
# Use the public `get_metadata()` accessor, as the diagnostic and quality cells
# do, instead of reaching into the `metadata` attribute directly.
get_column_plot(real_df, fake_df, column_name=col, metadata=gen.get_metadata())

In [19]:
# Persist the sampled synthetic table in the dataset's output folder
syn_path = ofolder / "syn_df.pkl"
fake_df.to_pickle(syn_path)