# 07 - Synthetic Data Generation with Gretel

**Purpose**: This notebook uses the Gretel platform's API to train a synthetic data model and generate artificial records. It demonstrates an alternative, cloud-based approach to synthetic data generation.

**Inputs**:
- Real data, accessed via the `BayesianData` class.

**Outputs**:
- A trained model in the Gretel Cloud.
- Downloaded model artifacts (e.g., the quality report) saved under `HTML_DIR`.

### 7.1 Setup and Gretel Model Submission

This cell prepares the data and submits a model training job to the Gretel Cloud:
1.  **Setup**: Imports the `gretel_client` helpers and defines constants, including a column-name mapping (`COL_MAP`) that shortens feature names for the Gretel API.
2.  **Data Preparation**: Loads the real data using `BayesianData`, rounds floating-point columns to four decimal places, renames the columns to their shortened aliases, and converts the result to a pandas DataFrame.
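3.  **Model Configuration**: Defines the configuration for a Gretel ACTGAN model, including hyperparameters such as the generator/discriminator layer sizes (`generator_dim`, `discriminator_dim`), their learning rates, and the number of records to generate (`SYNTHETIC_N`).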
4.  **Job Submission**: Creates (or reuses) a uniquely named Gretel project, builds a model object from the configuration and data, and submits it to the Gretel Cloud for training.

In [None]:
%reload_ext autoreload
%autoreload 2

import math
import os
from datetime import datetime

import polars as pl
from polars import DataFrame
import polars.selectors as cs
import pandas as pd

from gretel_client import Gretel
from gretel_client.projects.models import read_model_config
from gretel_client import create_or_get_unique_project
from gretel_client import poll

from early_markers.cribsy.common.constants import JSON_DIR, IPC_DIR, RAND_STATE, FEATURES, HTML_DIR, CSV_DIR
from early_markers.cribsy.common.bayes import BayesianData

# The Gretel API key is read from the environment instead of being hard-coded:
# a key committed to source control or a shared notebook must be treated as
# leaked. Export GRETEL_API_KEY before running this cell, and rotate any key
# that previously appeared in this file.
GRETEL_API_KEY_ENV = "GRETEL_API_KEY"

# Row-count constants; not referenced elsewhere in this cell.
NUM_ROWS = 2000
# NOTE(review): 0.740741 (~20/27) looks like the share of risk == 0 rows in the
# real data - confirm against the source data.
RISK_0_ROWS = math.ceil(0.740741 * NUM_ROWS)
RISK_1_ROWS = NUM_ROWS - RISK_0_ROWS

# Run date stamp in YYYYMMDD form (local, naive time).
TODAY = datetime.today().strftime("%Y%m%d")

# Original column name -> short alias used in the dataset sent to Gretel.
COL_MAP = {
    "category": "cat",
    "risk": "rsk",
    "age_bracket": "age",
    "Ankle_IQRaccx": "an_1",
    "Ankle_IQRaccy": "an_2",
    "Ankle_IQRvelx": "an_3",
    "Ankle_IQRvely": "an_4",
    "Ankle_IQRx": "an_5",
    "Ankle_IQRy": "an_6",
    "Ankle_lrCorr_x": "an_7",
    "Ankle_meanent": "an_8",
    "Ankle_medianvelx": "an_9",
    "Ankle_medianvely": "an_10",
    "Ankle_medianx": "an_11",
    "Ankle_mediany": "an_12",
    "Ear_lrCorr_x": "ea_1",
    "Elbow_IQR_acc_angle": "el_1",
    "Elbow_IQR_vel_angle": "el_2",
    "Elbow_entropy_angle": "el_3",
    "Elbow_lrCorr_angle": "el_4",
    "Elbow_lrCorr_x": "el_5",
    "Elbow_mean_angle": "el_6",
    "Elbow_median_vel_angle": "el_7",
    "Elbow_stdev_angle": "el_8",
    "Eye_lrCorr_x": "ey_5",
    "Hip_IQR_acc_angle": "hi_1",
    "Hip_IQR_vel_angle": "hi_2",
    "Hip_entropy_angle": "hi_3",
    "Hip_lrCorr_angle": "hi_4",
    "Hip_mean_angle": "hi_5",
    "Hip_median_vel_angle": "hi_6",
    "Hip_stdev_angle": "hi_7",
    "Knee_IQR_acc_angle": "kn_1",
    "Knee_IQR_vel_angle": "kn_2",
    "Knee_entropy_angle": "kn_3",
    "Knee_lrCorr_angle": "kn_4",
    "Knee_lrCorr_x": "kn_5",
    "Knee_mean_angle": "kn_6",
    "Knee_median_vel_angle": "kn_7",
    "Knee_stdev_angle": "kn_8",
    "Shoulder_IQR_acc_angle": "sh_1",
    "Shoulder_IQR_vel_angle": "sh_2",
    "Shoulder_entropy_angle": "sh_3",
    "Shoulder_lrCorr_angle": "sh_4",
    "Shoulder_mean_angle": "sh_5",
    "Shoulder_median_vel_angle": "sh_6",
    "Shoulder_stdev_angle": "sh_7",
    "Wrist_IQRaccx": "wr_1",
    "Wrist_IQRaccy": "wr_2",
    "Wrist_IQRvelx": "wr_3",
    "Wrist_IQRvely": "wr_4",
    "Wrist_IQRx": "wr_5",
    "Wrist_IQRy": "wr_6",
    "Wrist_lrCorr_x": "wr_7",
    "Wrist_meanent": "wr_8",
    "Wrist_medianvelx": "wr_9",
    "Wrist_medianvely": "wr_10",
    "Wrist_medianx": "wr_11",
    "Wrist_mediany": "wr_12",
}

# Short alias -> original column name, for mapping synthetic output back.
REV_MAP = {v: k for k, v in COL_MAP.items()}

# The 42-feature subset, plus the category/risk/age_bracket columns.
REAL_42 = [
    "category",
    "risk",
    "age_bracket",
    "Ankle_IQRaccx",
    "Ankle_IQRaccy",
    "Ankle_IQRvelx",
    "Ankle_IQRvely",
    "Ankle_IQRx",
    "Ankle_IQRy",
    "Ankle_meanent",
    "Ankle_medianvelx",
    "Ankle_medianvely",
    "Ankle_medianx",
    "Ankle_mediany",
    "Elbow_IQR_acc_angle",
    "Elbow_IQR_vel_angle",
    "Elbow_entropy_angle",
    "Elbow_mean_angle",
    "Elbow_median_vel_angle",
    "Hip_IQR_acc_angle",
    "Hip_IQR_vel_angle",
    "Hip_entropy_angle",
    "Hip_mean_angle",
    "Hip_median_vel_angle",
    "Hip_stdev_angle",
    "Knee_IQR_acc_angle",
    "Knee_IQR_vel_angle",
    "Knee_entropy_angle",
    "Knee_lrCorr_angle",
    "Knee_median_vel_angle",
    "Knee_stdev_angle",
    "Shoulder_IQR_acc_angle",
    "Shoulder_IQR_vel_angle",
    "Shoulder_mean_angle",
    "Shoulder_median_vel_angle",
    "Wrist_IQRaccx",
    "Wrist_IQRaccy",
    "Wrist_IQRvelx",
    "Wrist_IQRvely",
    "Wrist_IQRx",
    "Wrist_IQRy",
    "Wrist_meanent",
    "Wrist_medianvelx",
    "Wrist_medianvely",
    "Wrist_mediany",
]

# COL_MAP restricted to the REAL_42 columns, and its inverse (not used in this cell).
KEEPS = {k: v for k, v in COL_MAP.items() if k in REAL_42}
REV_KEEPS = {v: k for k, v in KEEPS.items()}

# Run identifiers, used only to build the Gretel model name.
# NOTE(review): COL_MAP holds 59 columns (56 features + category, risk,
# age_bracket), yet K = 58 - confirm the intended value.
K = 58
TRIAL = 1
# Number of synthetic records Gretel generates after training.
SYNTHETIC_N = 5000

bd = BayesianData()

# Round every Float64 column to 4 decimals, keep only the COL_MAP columns, and
# rename them to their short aliases.
df_all = (
    bd.base_wide
    .with_columns(pl.col(pl.Float64).round(4))
    .select(list(COL_MAP))
    .rename(COL_MAP)
)

# Gretel ACTGAN model configuration.
config = {
    "schema_version": "1.0",
    "name": f"emma_all_k{K}_{TRIAL}",
    "models": [
        {
            "actgan": {
                # Placeholder; the actual training data is passed via
                # `data_source=` in `create_model_obj` below.
                "data_source": "__tmp__",
                "params": {
                    "epochs": "auto",
                    "generator_dim": [1024, 1024],
                    "discriminator_dim": [1024, 1024],
                    "generator_lr": 0.0001,
                    "discriminator_lr": 0.00033,
                    "batch_size": "auto",
                    "auto_transform_datetimes": False
                },
                "generate": {
                    "num_records": SYNTHETIC_N
                },
                "privacy_filters": {
                    "outliers": None, "similarity": None
                }
            }
        }
    ]
}

gretel_api_key = os.environ.get(GRETEL_API_KEY_ENV)
if not gretel_api_key:
    raise RuntimeError(
        f"Set the {GRETEL_API_KEY_ENV} environment variable to your Gretel API key."
    )
# NOTE(review): `gretel` is not used directly below; constructing it is
# presumably what configures the client session that
# `create_or_get_unique_project` relies on - confirm.
gretel = Gretel(api_key=gretel_api_key)

proj = create_or_get_unique_project(name="cribsy-sample-size", display_name="Cribsy Sample Size")
model = proj.create_model_obj(model_config=config, data_source=df_all.to_pandas())
# submit_cloud() queues the training job; progress is tracked separately with
# poll(model), so the job is not finished when this cell returns.
model.submit_cloud()

print(f"Submitted @ {datetime.now()}")


### 7.2 Download Model Artifacts

This cell downloads the artifacts associated with the trained Gretel model, such as the quality report, to the local `HTML_DIR` for inspection. Artifacts are only available once training has completed.

In [None]:
# Artifacts (e.g. the quality report) exist only after training finishes, so
# wait on the job first; poll() on an already-finished job is expected to
# return promptly. The target directory reuses the model name from `config`
# ("emma_all_k{K}_{TRIAL}") so the two cannot drift apart.
poll(model)
model.download_artifacts(HTML_DIR / config["name"])

### 7.3 Poll for Model Status

This cell uses Gretel's `poll` function to monitor the training job in the Gretel Cloud, printing real-time status updates as the model progresses through the pending, active, and completed states.

In [None]:
# Wait on the Gretel training job submitted above, printing its status as it
# progresses.
# Reference run: Model(id=67eed5344f7b056cda725f19, project=proj_2vELdPVVBEaVYFzQ8XGMOSkfhru)
poll(model)