In [45]:
import pickle
import re
from pathlib import Path

import pandas as pd
from sdv.evaluation.single_table import (
    evaluate_quality,
)
from tqdm import tqdm

# Jupyter does not define __file__, so the project root is located relative to
# the kernel's working directory (two levels above it).
# NOTE(review): assumes the notebook is launched from its own folder, two
# levels below the project root -- confirm for other launch locations.
PROJECT_ROOT = Path.cwd().resolve().parent.parent
INPUT_FOLDER = PROJECT_ROOT / "data/input"
OUTPUT_FOLDER = PROJECT_ROOT / "data/output"
# Make sure the output tree exists before any later cell reads from or writes to it.
OUTPUT_FOLDER.mkdir(parents=True, exist_ok=True)

In [31]:
# Locate the dataset folders and the trained model pickles to loop through.
# The resulting table is the data source for a plot built in Excel.
ifolder = INPUT_FOLDER / "UCI_winequality"
ofolder = OUTPUT_FOLDER / "UCI_winequality"

# Get original data (read from the output folder as a pickled DataFrame)
real_df = pd.read_pickle(ofolder / "real_df.pkl")

# Create table of metadata based on folder name.
# Folder names are expected to look like "epoch<E>batch<B>" (optional
# whitespace around the numbers); both numbers are captured in one pass.
folder_pattern = re.compile(r"\s*epoch\s*(\d+)\s*batch\s*(\d+)\s*")
folders = (ofolder / "grid").glob("*")
folders_meta = []

for folder in folders:
    parsed = folder_pattern.fullmatch(folder.stem)
    # Skip stray entries (e.g. .DS_Store, .ipynb_checkpoints, loose files)
    # instead of crashing on int("") with an unhelpful ValueError.
    if parsed is None or not folder.is_dir():
        continue
    folders_meta.append(
        {
            "epoch": int(parsed.group(1)),
            "batch": int(parsed.group(2)),
            "pkl": folder / "ctgan.pkl",
        }
    )

# glob() yields entries in filesystem order; sort so the exported table is
# reproducible and already ordered for plotting.
folders_meta.sort(key=lambda meta: (meta["epoch"], meta["batch"]))

In [40]:
# Score every trained model against the real data and store the overall
# quality score plus the two property scores on its metadata record.
# NOTE(review): pickle.load can execute arbitrary code -- only unpickle model
# files that were produced by this project.
for entry in tqdm(folders_meta):
    with open(entry.get("pkl"), "rb") as fh:
        gen = pickle.load(fh)

    # Draw as many synthetic rows as there are real rows.
    fake_df = gen.sample(num_rows=len(real_df))

    gen_quality = evaluate_quality(
        real_data=real_df,
        synthetic_data=fake_df,
        metadata=gen.get_metadata(),
        verbose=False,
    )

    properties = gen_quality.get_properties()
    scores = {"OverallScore": float(gen_quality.get_score())}
    for prop in ("Column Shapes", "Column Pair Trends"):
        # Value is read by position: first matching row, second column.
        scores[prop] = float(properties[properties.Property == prop].iloc[0, 1])
    entry.update(scores)


100%|██████████| 45/45 [00:24<00:00,  1.81it/s]


In [55]:
# Tabulate the per-model results; the pickle path is only needed for loading
# the models, so it is left out of the exported table.
folders_df = pd.DataFrame(folders_meta).drop(columns=["pkl"])
folders_df.to_csv(ofolder / "grid_results.csv", index=False)

# Preview goes last: Jupyter only renders the final expression of a cell.
folders_df.head()