- Ref: https://docs.sdv.dev/sdv/

In [148]:
%load_ext lab_black
import pandas as pd
from sdv.datasets.demo import download_demo, get_available_demos
from sdv.metadata import SingleTableMetadata
from sdv.single_table import (
    CTGANSynthesizer,
    TVAESynthesizer,
    GaussianCopulaSynthesizer,
    CopulaGANSynthesizer,
)
from sdv.lite import SingleTablePreset
from sdv.evaluation.single_table import evaluate_quality
import warnings
import time

warnings.filterwarnings("ignore")

The lab_black extension is already loaded. To reload it, use:
  %reload_ext lab_black


In [149]:
cardio = pd.read_csv("cardio_final.csv")
cardio

Unnamed: 0,id,age,gender,height,weight,systolic,diastolic,cholesterol,glucose,smoke,alcohol_intake,physical_activity,cv_disease,bmi
0,0,51,Male,168,62.0,110,80,Normal,Normal,False,False,True,False,22.0
1,1,56,Female,156,85.0,140,90,Extremely High,Normal,False,False,True,True,34.9
2,2,52,Female,165,64.0,130,70,Extremely High,Normal,False,False,False,True,23.5
3,3,49,Male,169,82.0,150,100,Normal,Normal,False,False,True,True,28.7
4,4,48,Female,156,56.0,100,60,Normal,Normal,False,False,False,False,23.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69995,99993,53,Male,168,76.0,120,80,Normal,Normal,True,False,True,False,26.9
69996,99995,62,Female,158,126.0,140,90,High,High,False,False,True,True,50.5
69997,99996,53,Male,183,105.0,180,90,Extremely High,Normal,False,True,False,True,31.4
69998,99998,62,Female,163,72.0,135,80,Normal,High,False,False,False,True,27.1


In [150]:
cardio.age.value_counts()

56    3941
54    3891
58    3708
60    3601
57    3592
55    3583
50    3435
59    3387
52    3385
53    3262
51    3197
61    3176
64    2746
62    2738
48    2200
63    2188
65    2177
46    2098
44    2042
42    1912
49    1808
40    1788
41    1616
47    1615
45    1503
43    1407
30       3
31       1
Name: age, dtype: int64

In [151]:
bins = [0, 30, 60, 100]

"""
Youth - <30
Adults - 30 ~ 60
Elderly - >60

"""
cardio["range_age"] = pd.cut(
    cardio["age"], bins=bins, labels=["Youth", "Adults", "Elderly"]
).astype("object")
cardio = cardio[
    [
        "id",
        "age",
        "range_age",
        "gender",
        "height",
        "weight",
        "systolic",
        "diastolic",
        "cholesterol",
        "glucose",
        "smoke",
        "alcohol_intake",
        "physical_activity",
        "cv_disease",
        "bmi",
    ]
]
cardio

Unnamed: 0,id,age,range_age,gender,height,weight,systolic,diastolic,cholesterol,glucose,smoke,alcohol_intake,physical_activity,cv_disease,bmi
0,0,51,Adults,Male,168,62.0,110,80,Normal,Normal,False,False,True,False,22.0
1,1,56,Adults,Female,156,85.0,140,90,Extremely High,Normal,False,False,True,True,34.9
2,2,52,Adults,Female,165,64.0,130,70,Extremely High,Normal,False,False,False,True,23.5
3,3,49,Adults,Male,169,82.0,150,100,Normal,Normal,False,False,True,True,28.7
4,4,48,Adults,Female,156,56.0,100,60,Normal,Normal,False,False,False,False,23.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69995,99993,53,Adults,Male,168,76.0,120,80,Normal,Normal,True,False,True,False,26.9
69996,99995,62,Elderly,Female,158,126.0,140,90,High,High,False,False,True,True,50.5
69997,99996,53,Adults,Male,183,105.0,180,90,Extremely High,Normal,False,True,False,True,31.4
69998,99998,62,Elderly,Female,163,72.0,135,80,Normal,High,False,False,False,True,27.1


In [152]:
cardio["range_age"].value_counts()

Adults     56972
Elderly    13025
Youth          3
Name: range_age, dtype: int64

In [153]:
cardio.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 70000 entries, 0 to 69999
Data columns (total 15 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   id                 70000 non-null  int64  
 1   age                70000 non-null  int64  
 2   range_age          70000 non-null  object 
 3   gender             70000 non-null  object 
 4   height             70000 non-null  int64  
 5   weight             70000 non-null  float64
 6   systolic           70000 non-null  int64  
 7   diastolic          70000 non-null  int64  
 8   cholesterol        70000 non-null  object 
 9   glucose            70000 non-null  object 
 10  smoke              70000 non-null  bool   
 11  alcohol_intake     70000 non-null  bool   
 12  physical_activity  70000 non-null  bool   
 13  cv_disease         70000 non-null  bool   
 14  bmi                70000 non-null  float64
dtypes: bool(4), float64(2), int64(5), object(4)
memory usage: 6.1+ MB


## Metadata

In [154]:
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data=cardio)

In [155]:
metadata

{
    "columns": {
        "id": {
            "sdtype": "numerical"
        },
        "age": {
            "sdtype": "numerical"
        },
        "range_age": {
            "sdtype": "categorical"
        },
        "gender": {
            "sdtype": "categorical"
        },
        "height": {
            "sdtype": "numerical"
        },
        "weight": {
            "sdtype": "numerical"
        },
        "systolic": {
            "sdtype": "numerical"
        },
        "diastolic": {
            "sdtype": "numerical"
        },
        "cholesterol": {
            "sdtype": "categorical"
        },
        "glucose": {
            "sdtype": "categorical"
        },
        "smoke": {
            "sdtype": "boolean"
        },
        "alcohol_intake": {
            "sdtype": "boolean"
        },
        "physical_activity": {
            "sdtype": "boolean"
        },
        "cv_disease": {
            "sdtype": "boolean"
        },
        "bmi": {
            "sdtype": "nu

In [156]:
# check errors

metadata.validate()

In [157]:
metadata.update_column(column_name="id", sdtype="id", regex_format="[0-9]{5}")


metadata.update_column(
    column_name="age", sdtype="numerical", computer_representation="Int32"
)

metadata.update_column(
    column_name="height", sdtype="numerical", computer_representation="Int64"
)

metadata.update_column(
    column_name="weight", sdtype="numerical", computer_representation="Float"
)

metadata.update_column(
    column_name="systolic", sdtype="numerical", computer_representation="Int64"
)

metadata.update_column(
    column_name="diastolic", sdtype="numerical", computer_representation="Int64"
)

metadata.update_column(
    column_name="bmi", sdtype="numerical", computer_representation="Float"
)

metadata.update_column(column_name="smoke", sdtype="boolean")

metadata.update_column(column_name="alcohol_intake", sdtype="boolean")

metadata.update_column(column_name="physical_activity", sdtype="boolean")

metadata.update_column(column_name="cv_disease", sdtype="boolean")

metadata.update_column(column_name="range_age", sdtype="categorical")

In [158]:
metadata.set_primary_key(column_name="id")

In [159]:
metadata

{
    "columns": {
        "id": {
            "sdtype": "id",
            "regex_format": "[0-9]{5}"
        },
        "age": {
            "sdtype": "numerical",
            "computer_representation": "Int32"
        },
        "range_age": {
            "sdtype": "categorical"
        },
        "gender": {
            "sdtype": "categorical"
        },
        "height": {
            "sdtype": "numerical",
            "computer_representation": "Int64"
        },
        "weight": {
            "sdtype": "numerical",
            "computer_representation": "Float"
        },
        "systolic": {
            "sdtype": "numerical",
            "computer_representation": "Int64"
        },
        "diastolic": {
            "sdtype": "numerical",
            "computer_representation": "Int64"
        },
        "cholesterol": {
            "sdtype": "categorical"
        },
        "glucose": {
            "sdtype": "categorical"
        },
        "smoke": {
            "sdtype": "b

In [160]:
cardio.glucose.value_counts()

Normal            59479
Extremely High     5331
High               5190
Name: glucose, dtype: int64

In [161]:
cardio.cholesterol.value_counts()

Normal            52385
High               9549
Extremely High     8066
Name: cholesterol, dtype: int64

In [162]:
cardio.range_age.value_counts()

Adults     56972
Elderly    13025
Youth          3
Name: range_age, dtype: int64

In [163]:
sample = 20
cardio = cardio.sample(sample)

## Constraints

In [164]:
age_constraint = {
    "constraint_class": "ScalarRange",
    "constraint_parameters": {
        "column_name": "age",
        "low_value": 30.0,
        "high_value": 70.0,
        "strict_boundaries": False,
    },
}


# age_and_rangeage_constraint = {
#     "constraint_class": "FixedCombinations",
#     "constraint_parameters": {"column_names": ["age", "range_age"]},
# }

## Synthetize

In [165]:
synthesizer = CTGANSynthesizer(metadata)
synthesizer.add_constraints(constraints=[age_constraint])

In [167]:
%%time
print(f"Start synthetizer \t number of samples: {sample}")
synthesizer.fit(cardio)

Start synthetizer 	 number of samples: 20
CPU times: total: 28min 21s
Wall time: 4min 44s


In [16]:
synthesizer.save("my_synthesizer.pkl")

synthesizer = CTGANSynthesizer.load("my_synthesizer.pkl")

In [11]:
synthetic_data = synthesizer.sample(num_rows=100)
synthetic_data.head()

Unnamed: 0,id,age,gender,height,weight,systolic,diastolic,cholesterol,glucose,smoke,alcohol_intake,physical_activity,cv_disease,bmi
0,0,48,Female,148,79.0,113,76,Normal,High,False,False,True,False,28.7
1,1,48,Male,148,52.0,90,60,Normal,Normal,False,False,True,True,32.0
2,2,58,Female,149,59.0,137,92,Normal,Normal,False,False,True,True,24.7
3,3,62,Female,148,52.0,94,79,High,High,False,False,False,False,30.8
4,4,44,Female,148,51.0,140,68,Normal,Normal,False,False,True,False,21.0


In [13]:
quality_report = evaluate_quality(cardio, synthetic_data, metadata)

Creating report: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 11.73it/s]



Overall Quality Score: 74.47%

Properties:
Column Shapes: 79.08%
Column Pair Trends: 69.86%
