In [1]:
import numpy as np
import pandas as pd
from sdv_train_moneyball import load_data, get_moneyball_metadata, convert_category_data
from sdv.tabular import CopulaGAN
from sdmetrics.reports.single_table import QualityReport, DiagnosticReport
from sdmetrics.reports import utils


In [2]:
data_real = load_data(41021)
data_real = convert_category_data(data_real)
table_metadata = get_moneyball_metadata()


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1232 entries, 0 to 1231
Data columns (total 15 columns):
 #   Column        Non-Null Count  Dtype   
---  ------        --------------  -----   
 0   Team          1232 non-null   category
 1   League        1232 non-null   category
 2   Year          1232 non-null   int64   
 3   RS            1232 non-null   int64   
 4   RA            1232 non-null   int64   
 5   W             1232 non-null   uint8   
 6   OBP           1232 non-null   float64 
 7   SLG           1232 non-null   float64 
 8   BA            1232 non-null   float64 
 9   Playoffs      1232 non-null   category
 10  RankSeason    244 non-null    category
 11  RankPlayoffs  244 non-null    category
 12  G             1232 non-null   category
 13  OOBP          420 non-null    float64 
 14  OSLG          420 non-null    float64 
dtypes: category(6), float64(5), int64(3), uint8(1)
memory usage: 88.1 KB
None


In [3]:
model = CopulaGAN.load("sdv_copulagan_moneyball.pkl")


In [4]:
data_synth = model.sample(num_rows=10_000)


In [5]:
data_synth


Unnamed: 0,Team,League,Year,RS,RA,W,OBP,SLG,BA,Playoffs,RankSeason,RankPlayoffs,G,OOBP,OSLG
0,NYM,NL,1994,729,682,85,0.330539,0.402151,0.259320,1,6,1,162,0.320191,0.390059
1,MIN,AL,1999,851,773,85,0.362783,0.437339,0.285191,0,,,162,0.339547,0.420715
2,KCR,AL,1978,724,728,79,0.334482,0.401055,0.273034,0,,,161,,
3,BAL,AL,1990,676,765,61,0.314917,0.375985,0.250967,0,,,159,,
4,TEX,AL,1969,730,695,83,0.328199,0.398837,0.269390,0,,,162,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,MIN,AL,1970,686,692,74,0.325259,0.351936,0.270263,0,,,159,,
9996,SFG,NL,2010,807,875,83,0.341341,0.438514,0.275597,1,3,4,162,0.333829,0.427839
9997,SFG,NL,1966,544,540,82,0.305889,0.321007,0.224577,0,4,4,162,,
9998,MIN,AL,2007,704,650,84,0.329475,0.395216,0.267514,0,4,2,162,0.318113,0.395158


## Quality Report

In [6]:
quality_report = QualityReport()
quality_report.generate(data_real, data_synth, table_metadata.to_dict())


Creating report: 100%|██████████| 4/4 [00:00<00:00,  4.66it/s]



Overall Quality Score: 89.1%

Properties:
Column Shapes: 90.19%
Column Pair Trends: 88.02%


In [7]:
for name in ["Column Shapes", "Column Pair Trends"]:
    fig = quality_report.get_visualization(property_name=name)
    fig.show()


## Diagnostic Report

In [8]:
diagnostic_report = DiagnosticReport()
diagnostic_report.generate(data_real, data_synth, table_metadata.to_dict())


Creating report: 100%|██████████| 4/4 [00:15<00:00,  3.77s/it]


DiagnosticResults:

SUCCESS:
✓ The synthetic data covers over 90% of the numerical ranges present in the real data
✓ The synthetic data covers over 90% of the categories present in the real data
✓ Over 90% of the synthetic rows are not copies of the real data
✓ The synthetic data follows over 90% of the min/max boundaries set by the real data





In [9]:
for name in ["Synthesis", "Coverage", "Boundaries"]:
    fig = diagnostic_report.get_visualization(property_name=name)
    fig.show()


In [10]:
for col in data_real.columns:
    fig = utils.get_column_plot(
        real_data=data_real,
        synthetic_data=data_synth,
        column_name=col,
        metadata=table_metadata.to_dict(),
    )
    fig.show()
