<a href="https://colab.research.google.com/github/henomoto1025/synthetic-data-notebooks/blob/main/Synthcity.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## インストール

In [1]:
!pip install synthcity



## 学習元データの読み込み

In [2]:
import pandas as pd

# Wine-quality table (reg_num split) from the inria-soda tabular benchmark
# hosted on Hugging Face; it is used as the real training data.
csv_url = "https://huggingface.co/datasets/inria-soda/tabular-benchmark/raw/main/reg_num/wine_quality.csv"
df_real = pd.read_csv(filepath_or_buffer=csv_url)
df_real

Unnamed: 0,fixed.acidity,volatile.acidity,citric.acid,residual.sugar,chlorides,free.sulfur.dioxide,total.sulfur.dioxide,density,pH,sulphates,alcohol,quality
0,7.4,0.70,0.00,1.9,0.076,11.0,34.0,0.99780,3.51,0.56,9.4,5
1,7.8,0.88,0.00,2.6,0.098,25.0,67.0,0.99680,3.20,0.68,9.8,5
2,7.8,0.76,0.04,2.3,0.092,15.0,54.0,0.99700,3.26,0.65,9.8,5
3,11.2,0.28,0.56,1.9,0.075,17.0,60.0,0.99800,3.16,0.58,9.8,6
4,7.4,0.70,0.00,1.9,0.076,11.0,34.0,0.99780,3.51,0.56,9.4,5
...,...,...,...,...,...,...,...,...,...,...,...,...
6492,6.2,0.21,0.29,1.6,0.039,24.0,92.0,0.99114,3.27,0.50,11.2,6
6493,6.6,0.32,0.36,8.0,0.047,57.0,168.0,0.99490,3.15,0.46,9.6,5
6494,6.5,0.24,0.19,1.2,0.041,30.0,111.0,0.99254,2.99,0.46,9.4,6
6495,5.5,0.29,0.30,1.1,0.022,20.0,110.0,0.98869,3.34,0.38,12.8,7


## Synthcityプラグインの初期化

In [3]:
from synthcity.plugins import Plugins

# Restrict the registry to the "generic" and "privacy" categories and
# show the names of the generators it exposes.
available_plugins = Plugins(categories=["generic", "privacy"])
available_plugins.list()

[2024-10-09T08:48:46.968458+0000][4615][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py


['tvae',
 'bayesian_network',
 'ctgan',
 'nflow',
 'adsgan',
 'aim',
 'dpgan',
 'dummy_sampler',
 'marginal_distributions',
 'uniform_sampler',
 'ddpm',
 'pategan',
 'decaf',
 'great',
 'arf',
 'rtvae',
 'privbayes']

In [4]:
from synthcity.plugins import Plugins

# n_iter: number of training epochs (the default is 1000).
plugin_loader = Plugins()
model = plugin_loader.get("tvae", n_iter=100)

[2024-10-09T08:48:47.825965+0000][4615][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py
[2024-10-09T08:48:47.825965+0000][4615][CRITICAL] module disabled: /usr/local/lib/python3.10/dist-packages/synthcity/plugins/generic/plugin_goggle.py


## TVAEの学習

In [5]:
# Train the TVAE plugin selected above on the real table. The call is the
# cell's last expression, so the notebook also echoes the returned plugin object.
model.fit(df_real)

100%|██████████| 100/100 [00:52<00:00,  1.91it/s]


<synthcity.plugins.generic.plugin_tvae.TVAEPlugin at 0x7b1fb87be6e0>

## 合成データの生成

In [6]:
# Draw as many synthetic rows as the real table has, so that the two
# tables can be compared on equal footing.
n_samples = df_real.shape[0]
generated = model.generate(n_samples)
df_fake = generated.dataframe()
df_fake

Unnamed: 0,fixed.acidity,volatile.acidity,citric.acid,residual.sugar,chlorides,free.sulfur.dioxide,total.sulfur.dioxide,density,pH,sulphates,alcohol,quality
0,7.292578,0.405549,0.291676,13.061412,0.081060,16.867325,56.608618,0.996364,3.215700,0.758424,10.924050,5
1,6.484425,0.262029,0.317198,1.343769,0.032651,38.379586,120.685488,0.991030,3.096187,0.653125,12.550117,6
2,6.007896,0.226987,0.243233,1.742370,0.040186,17.062752,114.576735,0.990299,3.224001,0.354428,12.664931,6
3,5.617659,0.281146,0.253631,7.164859,0.042467,29.160689,93.816037,0.991015,3.387058,0.378551,12.272254,7
4,6.422022,0.274756,0.481685,13.471893,0.145601,54.349765,207.325948,0.995962,3.108027,0.488421,9.274753,6
...,...,...,...,...,...,...,...,...,...,...,...,...
6492,11.044282,0.411075,0.528136,1.837356,0.082841,17.248673,25.617176,0.997192,3.018828,0.664517,9.329651,5
6493,8.475172,0.240910,0.706108,3.598939,0.056758,24.415972,123.953081,0.995958,3.072013,0.628126,9.444921,5
6494,10.453969,0.345954,0.498133,2.371383,0.032739,6.650552,21.944919,0.997200,3.048822,0.518269,11.050867,6
6495,7.281383,0.588537,0.394328,16.350679,0.051857,65.744272,205.446784,0.998677,3.183857,0.607332,9.086747,6


## 合成データの評価

In [7]:
!pip install sdv



In [8]:
from sdv.evaluation.single_table import evaluate_quality
from sdv.metadata import SingleTableMetadata

# Let SDV infer the column metadata (sdtypes) from the real table.
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(df_real)

# Score the synthetic table against the real one; the report prints the
# "Column Shapes" and "Column Pair Trends" scores while it is generated.
report_inputs = {
    "real_data": df_real,
    "synthetic_data": df_fake,
    "metadata": metadata,
}
quality_report = evaluate_quality(**report_inputs)

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 12/12 [00:00<00:00, 187.77it/s]|
Column Shapes Score: 91.12%

(2/2) Evaluating Column Pair Trends: |██████████| 66/66 [00:00<00:00, 140.19it/s]|
Column Pair Trends Score: 89.71%

Overall Score (Average): 90.41%



## 評価の視覚化

In [9]:
# Render the report's figure for the "Column Pair Trends" property.
quality_report.get_visualization(property_name="Column Pair Trends")

In [10]:
# Render the report's figure for the "Column Shapes" property.
quality_report.get_visualization(property_name="Column Shapes")

In [11]:
from sdmetrics.visualization import get_column_plot

# Plot real vs. synthetic values for every column, one figure per column.
for col in df_real.columns:
    # Columns detected as categorical get a bar chart; every other sdtype
    # is drawn as a distribution plot.
    is_categorical = metadata.columns[col]["sdtype"] == "categorical"
    chart_kind = "bar" if is_categorical else "distplot"
    fig = get_column_plot(df_real, df_fake, col, plot_type=chart_kind)
    fig.show()