# 02 - Synthetic Data Generation with SDV

**Purpose**: This notebook demonstrates a more customizable approach to synthetic data generation using the `sdv` (Synthetic Data Vault) library. It provides fine-grained control over the marginal distribution used to model each feature.

**Inputs**:
- Real data, accessed via the `BayesianData` class.
- A predefined list of 23 `keeps` features, presumably a subset identified in previous analyses.

**Outputs**:
- `sdv_metadata.json`: A JSON file defining the schema and data types for the SDV synthesizer.
- `synth_sdv_{NUM_ROWS}_long.ipc`: The final synthetic dataset, saved in the long format.
- In-line reports from `sdv`: the diagnostic (`Data Validity Score`, `Data Structure Score`) and the quality report (`Column Shapes Score`, `Column Pair Trends Score`).
- A Plotly figure showing the distribution comparisons for all `keeps` features.

### 2.1 Setup, Metadata Configuration, and Synthesis

This cell performs the entire SDV workflow:
1.  **Setup**: Imports `sdv` components and defines the number of synthetic rows to generate.
2.  **Feature Subsetting**: Selects the identifier and label columns (`infant`, `category`, `risk`) together with the `keeps` features from the wide-format real data.
3.  **Metadata Generation**: Detects metadata from the real data, explicitly sets the `risk` column's sdtype to `categorical`, and saves the result to `sdv_metadata.json`.
4.  **Distribution Specification**: The `distros` dictionary assigns each `keeps` feature the marginal distribution (`gaussian_kde`, `beta`, or `norm`) that the copula should fit; the commented-out entries belong to features outside `keeps`. This provides granular control over the synthesis process.
5.  **Synthesizer Configuration**: Instantiates the `GaussianCopulaSynthesizer` with the metadata and distributions, using `enforce_min_max_values=True` so that sampled values stay within each column's observed range. It also adds a `FixedCombinations` constraint so that every synthetic (`risk`, `category`) pair is one that occurs in the real data.
6.  **Fit and Sample**: Fits the synthesizer to the real data and samples the specified number of new synthetic rows.

In [1]:
%reload_ext autoreload
%autoreload 2

import math

import polars as pl
from polars import DataFrame
import pandas as pd

from sdv.single_table import CTGANSynthesizer, CopulaGANSynthesizer, GaussianCopulaSynthesizer
from sdv.metadata import Metadata
from sdv.evaluation.single_table import run_diagnostic, evaluate_quality

from early_markers.cribsy.common.bayes import BayesianData
from early_markers.cribsy.common.constants import (
    IPC_DIR,
    JSON_DIR,
)


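NUM_ROWS = 1000
# Per-class row targets (0.740741 is the target share of risk == 0); not used by the sampling below
RISK_0_ROWS = math.ceil(0.740741 * NUM_ROWS)
RISK_1_ROWS = NUM_ROWS - RISK_0_ROWS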

keeps = [
    'Ankle_IQRx',
    'Ankle_IQRy',
    'Ankle_medianvelx',
    'Ankle_medianvely',
    'Ankle_mediany',
    'Ear_lrCorr_x',
    'Elbow_entropy_angle',
    'Elbow_lrCorr_angle',
    'Elbow_median_vel_angle',
    'Hip_mean_angle',
    'Hip_median_vel_angle',
    'Hip_stdev_angle',
    'Knee_lrCorr_x',
    'Knee_mean_angle',
    'Knee_median_vel_angle',
    'Knee_stdev_angle',
    'Shoulder_median_vel_angle',
    'Wrist_IQRx',
    'Wrist_lrCorr_x',
    'Wrist_medianvelx',
    'Wrist_medianvely',
    'Wrist_medianx',
    'age_in_weeks'
]

distros = {
    # "risk_raw": "beta",
    # "category": "beta",
    # "Ankle_IQRaccx": "gaussian_kde",
    # "Ankle_IQRaccy": "gaussian_kde",
    # "Ankle_IQRvelx": "gaussian_kde",
    # "Ankle_IQRvely": "gaussian_kde",
    "Ankle_IQRx": "gaussian_kde",
    "Ankle_IQRy": "gaussian_kde",
    # "Ankle_lrCorr_x": "beta",
    # "Ankle_meanent": "gaussian_kde",
    "Ankle_medianvelx": "gaussian_kde",
    "Ankle_medianvely": "norm",
    # "Ankle_medianx": "norm",
    "Ankle_mediany": "gaussian_kde",
    "Ear_lrCorr_x": "beta",
    # "Elbow_IQR_acc_angle": "gaussian_kde",
    # "Elbow_IQR_vel_angle": "gaussian_kde",
    "Elbow_entropy_angle": "gaussian_kde",
    "Elbow_lrCorr_angle": "norm",
    # "Elbow_lrCorr_x": "norm",
    # "Elbow_mean_angle": "gaussian_kde",
    "Elbow_median_vel_angle": "gaussian_kde",
    # "Elbow_stdev_angle": "norm",
    # "Eye_lrCorr_x": "norm",  # ***
    # "Hip_IQR_acc_angle": "gaussian_kde",
    # "Hip_IQR_vel_angle": "gaussian_kde",
    # "Hip_entropy_angle": "beta",
    # "Hip_lrCorr_angle": "beta",
    # # "Hip_lrCorr_x": "gaussian_kde",  # ***
    "Hip_mean_angle": "gaussian_kde",
    "Hip_median_vel_angle": "beta",
    "Hip_stdev_angle": "gaussian_kde",
    # "Knee_IQR_acc_angle": "gaussian_kde",
    # "Knee_IQR_vel_angle": "gaussian_kde",
    # "Knee_entropy_angle": "gaussian_kde",
    # "Knee_lrCorr_angle": "norm",
    "Knee_lrCorr_x": "gaussian_kde",
    "Knee_mean_angle": "gaussian_kde",
    "Knee_median_vel_angle": "gaussian_kde",
    "Knee_stdev_angle": "beta",
    # "Shoulder_IQR_acc_angle": "gaussian_kde",
    # "Shoulder_IQR_vel_angle": "gaussian_kde",
    # "Shoulder_entropy_angle": "norm",
    # "Shoulder_lrCorr_angle": "norm",
    # "Shoulder_lrCorr_x": "beta",
    # "Shoulder_mean_angle": "gaussian_kde",
    "Shoulder_median_vel_angle": "gaussian_kde",
    # "Shoulder_stdev_angle": "gaussian_kde",
    # "Wrist_IQRaccx": "gaussian_kde",
    # "Wrist_IQRaccy": "gaussian_kde",
    # "Wrist_IQRvelx": "gaussian_kde",
    # "Wrist_IQRvely": "gaussian_kde",
    "Wrist_IQRx": "gaussian_kde",
    # "Wrist_IQRy": "gaussian_kde",
    "Wrist_lrCorr_x": "gaussian_kde", # beta
    # "Wrist_meanent": "gaussian_kde",
    "Wrist_medianvelx": "gaussian_kde",
    "Wrist_medianvely": "gaussian_kde",
    "Wrist_medianx": "gaussian_kde",
    # "Wrist_mediany": "gaussian_kde",
    "age_in_weeks": "gaussian_kde",
}



core_cols = [
    "infant",
    "category",
    "risk",
]

category_risk_constraint = {
    "constraint_class": "FixedCombinations",
    "constraint_parameters": {
        "column_names": ["risk", "category"]
    }
}


bd = BayesianData()

df = bd.base_wide.select(core_cols + keeps)

metadata = Metadata.detect_from_dataframe(
    data=df.to_pandas(),
    table_name="features",
)
metadata.update_column(
    column_name="risk",
    sdtype="categorical",
)
metadata.save_to_json(JSON_DIR / "sdv_metadata.json", mode="overwrite")

# metadata.visualize()

synthesizer = GaussianCopulaSynthesizer(
    metadata,
    numerical_distributions=distros,
    enforce_min_max_values=True,
)

# synthesizer = CopulaGANSynthesizer(
#     metadata, # required
#     enforce_min_max_values=True,
#     enforce_rounding=False,
#     numerical_distributions=distros,
#     epochs=500,
#     verbose=True
# )

# synthesizer = CTGANSynthesizer(
#     metadata=metadata,
#     enforce_rounding=True,
#     enforce_min_max_values=True,
#     epochs=500,
# )

# synthesizer = CopulaGANSynthesizer(
#     metadata=metadata,
#     enforce_rounding=True,
#     enforce_min_max_values=True,
#     epochs=500,
# )
synthesizer.add_constraints(
    constraints=[
        category_risk_constraint
    ]
)

synthesizer.fit(df.to_pandas())

data_synth = synthesizer.sample(NUM_ROWS)
# synthesizer.get_loss_values_plot()

Sampling rows: 100%|██████████| 1000/1000 [00:00<00:00, 1979.74it/s]
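The `FixedCombinations` constraint means that every (`risk`, `category`) pair in the synthetic data should also appear in the real data. A minimal stdlib sketch of that sanity check follows, assuming rows are represented as dictionaries (for example via Polars `df.to_dicts()` and pandas `data_synth.to_dict("records")`). The helper `unseen_combinations` is hypothetical, and the category labels in the usage are placeholders rather than the study's real categories.

```python
def unseen_combinations(real_rows, synth_rows, columns=("risk", "category")):
    """Return synthetic value combinations of `columns` that never occur in the real rows."""
    seen = {tuple(row[c] for c in columns) for row in real_rows}
    produced = {tuple(row[c] for c in columns) for row in synth_rows}
    return produced - seen


# Toy rows with placeholder category labels
real = [{"risk": 0, "category": "a"}, {"risk": 1, "category": "b"}]
synth_ok = [{"risk": 0, "category": "a"}, {"risk": 1, "category": "b"}, {"risk": 0, "category": "a"}]
synth_bad = synth_ok + [{"risk": 1, "category": "a"}]
print(unseen_combinations(real, synth_ok))   # set()
print(unseen_combinations(real, synth_bad))  # {(1, 'a')}
```

An empty set means the constraint held for the sampled rows.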


### 2.2 Save Synthetic Data

This cell takes the generated synthetic data, converts it to a Polars DataFrame, reshapes it from wide to long format using `unpivot`, and saves the final result to an IPC file for use in downstream analyses.
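The wide-to-long reshape can be illustrated with a plain-Python equivalent: each wide row yields one long row per non-index column. This is a sketch for intuition, not Polars' implementation; the function name `unpivot_rows` and the sample values are hypothetical, and row order may differ from Polars, which stacks one column at a time.

```python
def unpivot_rows(rows, index, variable_name="feature", value_name="value"):
    """Wide-to-long reshape of a list of dict rows."""
    long_rows = []
    for row in rows:
        # Every column not listed in `index` becomes a (feature, value) pair
        for col, val in row.items():
            if col in index:
                continue
            long_rows.append({**{k: row[k] for k in index}, variable_name: col, value_name: val})
    return long_rows


# One hypothetical wide row with two feature columns -> two long rows
wide = [{"infant": "i1", "category": 0, "risk": 0, "Ankle_IQRx": 0.23, "age_in_weeks": 12.0}]
long = unpivot_rows(wide, index=["infant", "category", "risk"])
for r in long:
    print(r)
```

Applied to this notebook, 1,000 synthetic rows with 23 `keeps` features produce a long table of 23,000 rows.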

In [2]:
df_synth = pl.DataFrame(data_synth)
# Wide -> long: one row per (infant, feature) pair, keeping the identifier and label columns
df_synth.unpivot(
    index=["infant", "category", "risk"],
    variable_name="feature",
    value_name="value",
).write_ipc(IPC_DIR / f"synth_sdv_{NUM_ROWS}_long.ipc")

### 2.3 Evaluate Data Quality

This cell uses `sdv`'s built-in evaluation tools to quantify the quality of the synthetic data. `run_diagnostic` checks basic validity (data validity and data structure), and `evaluate_quality` produces a `quality_report` that scores how well the synthetic data reproduces the real data's column shapes (per-column distributions) and column pair trends (pairwise relationships).
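As we understand the SDMetrics documentation, the Column Shapes score averages a per-column metric: `KSComplement` (1 minus the two-sample Kolmogorov-Smirnov statistic) for numerical columns and `TVComplement` (1 minus the total variation distance) for categorical ones. The stdlib sketch below reimplements both for illustration only; it is not the SDMetrics code.

```python
from bisect import bisect_right
from collections import Counter


def ks_complement(real, synth):
    """1 - max |F_real(v) - F_synth(v)| over all observed values v (empirical CDFs)."""
    r, s = sorted(real), sorted(synth)
    ks_stat = max(
        abs(bisect_right(r, v) / len(r) - bisect_right(s, v) / len(s))
        for v in set(r) | set(s)
    )
    return 1.0 - ks_stat


def tv_complement(real, synth):
    """1 - total variation distance between the category frequencies."""
    cr, cs = Counter(real), Counter(synth)
    tvd = 0.5 * sum(abs(cr[k] / len(real) - cs[k] / len(synth)) for k in set(cr) | set(cs))
    return 1.0 - tvd


print(ks_complement([1, 2, 3], [1, 2, 3]))                           # 1.0
print(tv_complement(["a", "a", "b", "b"], ["a", "a", "a", "b"]))     # 0.75
```

The Overall Score in the quality report output below is the mean of the two property scores: (94.15 + 95.70) / 2 = 94.925, displayed as 94.93%.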

In [3]:
diagnostic = run_diagnostic(df.to_pandas(), data_synth, metadata)
quality_report = evaluate_quality(df.to_pandas(), data_synth, metadata)


Generating report ...

(1/2) Evaluating Data Validity: |██████████| 26/26 [00:00<00:00, 5585.53it/s]|
Data Validity Score: 100.0%

(2/2) Evaluating Data Structure: |██████████| 1/1 [00:00<00:00, 920.01it/s]|
Data Structure Score: 100.0%

Overall Score (Average): 100.0%

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 26/26 [00:00<00:00, 2575.75it/s]|
Column Shapes Score: 94.15%

(2/2) Evaluating Column Pair Trends: |██████████| 325/325 [00:00<00:00, 1008.28it/s]|
Column Pair Trends Score: 95.7%

Overall Score (Average): 94.93%



### 2.4 Visualize Distribution Comparisons

This cell compares the real and synthetic distributions of every feature in the `keeps` list. It uses `sdmetrics.visualization.get_column_plot` to create one distribution plot per feature and then combines them into a two-column Plotly subplot grid for side-by-side comparison.
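The subplot placement fills a two-column grid row by row using integer division and remainder. The same arithmetic in isolation (the function name `grid_positions` is hypothetical):

```python
import math


def grid_positions(n_plots, n_cols=2):
    """Return the number of grid rows and the 1-based (row, col) of each plot."""
    n_rows = math.ceil(n_plots / n_cols)
    positions = [(i // n_cols + 1, i % n_cols + 1) for i in range(n_plots)]
    return n_rows, positions


n_rows, positions = grid_positions(23)  # 23 features in `keeps`
print(n_rows, positions[0], positions[-1])  # 12 (1, 1) (12, 1)
```

With `height=3600`, each of the 12 rows gets roughly 300 pixels, and the last row holds a single plot.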

In [4]:
from sdmetrics.visualization import get_column_plot
from plotly.subplots import make_subplots


figures = []
for f in keeps:
    figures.append(
        get_column_plot(
            real_data=df.to_pandas(),
            synthetic_data=data_synth,
            column_name=f,
            plot_type='distplot'
        )
    )

fig_main = make_subplots(
    rows=math.ceil(len(figures) / 2),
    cols=2,
    subplot_titles=keeps,
)

for i, f in enumerate(figures):
    row = i // 2 + 1
    col = i % 2 + 1
    # Copy every trace of the per-feature figure into its grid cell
    for trace in f.data:
        fig_main.add_trace(trace, row=row, col=col)

fig_main.update_layout(
    height=3600,
    showlegend=False,
    plot_bgcolor="white"
)
fig_main.show()

### 2.5 Display Quality Report and Learned Distributions

These final cells display the column-shape visualization from the SDV quality report and print the learned distribution for each feature, that is, the distribution family and fitted parameters of each marginal in the synthesizer.
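For columns modeled with `gaussian_kde`, the learned parameters appear to be the training values themselves (`dataset`), which is why the output is long: a kernel density estimate keeps every observation and places a Gaussian kernel on each. The sketch below evaluates a 1-D KDE with Scott's-rule bandwidth. We assume, without verifying it here, that SDV's KDE delegates to `scipy.stats.gaussian_kde`, whose default bandwidth rule is Scott's; `kde_pdf` is a hypothetical name.

```python
import math
from statistics import stdev


def kde_pdf(dataset, x):
    """Density at x of a 1-D Gaussian KDE using Scott's rule for the bandwidth."""
    n = len(dataset)
    h = stdev(dataset) * n ** (-1 / 5)  # Scott's factor n^(-1/(d+4)) with d = 1, times the sample std
    kernel_sum = sum(math.exp(-0.5 * ((x - xi) / h) ** 2) for xi in dataset)
    return kernel_sum / (n * h * math.sqrt(2 * math.pi))


# First five values of the Ankle_IQRx `dataset` printed by get_learned_distributions()
sample = [0.2344, 0.1609, 0.0274, 0.073, 0.1475]
print(round(kde_pdf(sample, 0.15), 3))
```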

In [5]:
quality_report.get_visualization("Column Shapes")

In [6]:
synthesizer.get_learned_distributions()

{'Ankle_IQRx': {'distribution': 'gaussian_kde',
  'learned_parameters': {'dataset': [0.2344, 0.1609, 0.0274, 0.073, 0.1475, 0.4635, 0.1947, 0.2413, 0.1254, 0.5452, ...
(output truncated)

In [7]:
# from sdv.evaluation.single_table import get_column_pair_plot
#
# fig = get_column_pair_plot(
#     real_data=df.to_pandas(),
#     synthetic_data=data_synth,
#     metadata=metadata,
#     column_names=['Ankle_IQRx', 'Ankle_IQRy'],
# )
#
# fig.show()