## Train a Gretel ACTGAN synthetic data model locally

This notebook walks through training a model and generating synthetic data locally in your environment.

Follow the instructions at https://docs.gretel.ai/environment-setup to set up your local environment and GPU.

Prerequisites:

- Python 3.9+ (`python --version`).
- A GPU with CUDA configured is highly recommended (check with `nvidia-smi`).
- Ensure that Docker is running (`docker info`).
- The Gretel client SDK is installed and configured (`pip install -U gretel-client; gretel configure`).


In [1]:
import json
from pathlib import Path

import pandas as pd
from IPython.display import HTML
from smart_open import open  # shadows builtin open; transparently decompresses .gz files

from gretel_client import submit_docker_local
from gretel_client.projects import create_or_get_unique_project
from gretel_client.projects.models import read_model_config

# Public sample healthcare CSV used as the training input.
data_source = (
    "https://gretel-public-website.s3.us-west-2.amazonaws.com/datasets/sample-synthetic-healthcare.csv"
)

# Local directory that receives the Docker run artifacts
# (model.tar.gz, data_preview.gz, report.html.gz, data.gz).
target_dir = Path("tmp-actgan")

In [2]:
# Fetch the sample dataset and save a local copy (without the pandas index);
# the model config in the next cell uses this file as its data source.
df = pd.read_csv(data_source)
training_csv = "training_data.csv"
df.to_csv(training_csv, index=False)

# Rich display of the training frame (pandas truncates long tables itself).
df


Unnamed: 0,case_id,Hospital_code,Hospital_type_code,City_Code_Hospital,Hospital_region_code,Available Extra Rooms in Hospital,Department,Ward_Type,Ward_Facility_Code,Bed Grade,patientid,City_Code_Patient,Type of Admission,Severity of Illness,Visitors with Patient,Age,Admission_Deposit,Stay
0,1,8,c,3,Z,3,radiotherapy,R,F,2.0,31397,7.0,Emergency,Extreme,2,51-60,4911.0,0-10
1,2,2,c,5,Z,2,radiotherapy,S,F,2.0,31397,7.0,Trauma,Extreme,2,51-60,5954.0,41-50
2,3,10,e,1,X,2,anesthesia,S,E,2.0,31397,7.0,Trauma,Extreme,2,51-60,4745.0,31-40
3,4,26,b,2,Y,2,radiotherapy,R,D,2.0,31397,7.0,Trauma,Extreme,2,51-60,7272.0,41-50
4,5,26,b,2,Y,2,radiotherapy,S,D,2.0,31397,7.0,Trauma,Extreme,2,51-60,5558.0,41-50
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9994,9995,26,b,2,Y,2,radiotherapy,S,D,3.0,26396,8.0,Trauma,Minor,2,11-20,6253.0,21-30
9995,9996,27,a,7,Y,3,gynecology,S,C,3.0,26396,8.0,Trauma,Minor,2,11-20,5312.0,11-20
9996,9997,28,b,11,X,3,gynecology,R,F,3.0,26396,8.0,Trauma,Minor,2,11-20,4843.0,21-30
9997,9998,29,a,4,X,3,gynecology,S,F,2.0,26396,8.0,Trauma,Minor,2,11-20,5997.0,21-30


In [3]:
# Load the ACTGAN blueprint config and override its training parameters.
config = read_model_config("synthetics/tabular-actgan")

# Alias the single ACTGAN model section; assignments through it mutate `config`.
actgan_config = config["models"][0]["actgan"]

actgan_config["params"]["epochs"] = "auto"  # Valid ranges are typically 200-600 epochs

# Privacy filters are disabled for this demo. For sensitive data such as this
# healthcare sample, consider enabling them.
actgan_config["privacy_filters"]["similarity"] = None  # Enable by changing to "auto"
actgan_config["privacy_filters"]["outliers"] = None  # Enable by changing to "auto"

# Local CSV written by the previous cell.
actgan_config["data_source"] = "training_data.csv"

print(json.dumps(config, indent=2))


{
  "schema_version": "1.0",
  "name": "tabular-actgan",
  "models": [
    {
      "actgan": {
        "data_source": "training_data.csv",
        "params": {
          "epochs": "auto",
          "generator_dim": [
            1024,
            1024
          ],
          "discriminator_dim": [
            1024,
            1024
          ],
          "generator_lr": 0.0001,
          "discriminator_lr": 0.00033,
          "batch_size": "auto"
        },
        "generate": {
          "num_records": 5000
        },
        "privacy_filters": {
          "outliers": null,
          "similarity": null
        }
      }
    }
  ]
}


In [4]:
# Get or create the project, then train the model in a local Docker worker.
# The run writes its artifacts into `target_dir`, and later cells read them
# from there (data_preview.gz, report.html.gz, model.tar.gz).
target_dir.mkdir(exist_ok=True, parents=True)

project = create_or_get_unique_project(name="actgan-local")
model = project.create_model_obj(model_config=config)
run = submit_docker_local(
    model,
    output_dir=target_dir,
)


INFO: Starting poller
INFO: Status is created. Model creation has been queued.


{
    "uid": "63ffa29ce59169ec864bad78",
    "guid": "model_2MQRCc5H3MBRd3h6jIoJ3160Jw9",
    "model_name": "tabular-actgan",
    "runner_mode": "manual",
    "user_id": "60ec86ce492fbf1c604a6ea5",
    "user_guid": "user_26U3XeNlVkqbJkvxJp0vcBs0LJQ",
    "billing_domain": "gretel.ai",
    "billing_domain_guid": "domain_28bzIokk1eQdWUYsovba0VN1gtY",
    "project_id": "63ff81bba507100935e704ae",
    "project_guid": "proj_2MQA8pmBQ4DC2wGNeM938kx2pap",
    "status_history": {
        "created": "2023-03-01T19:08:12.776000Z"
    },
    "last_modified": "2023-03-01T19:08:12.911000Z",
    "status": "created",
    "last_active_hb": null,
    "duration_minutes": null,
    "error_msg": null,
    "error_id": null,
    "traceback": null,
    "annotations": null,
    "container_image": "074762682575.dkr.ecr.us-east-2.amazonaws.com/models/actgan@sha256:1b8fd472f046a15773116ce0101016e295f6c6cdd4b6fc6042365c789bda3a0a",
    "container_image_version": "4d0fd64c",
    "model_type": "actgan",
    "model_

INFO: Status is active. A worker has started creating your model!
2023-03-01T19:08:33.712249Z  Analyzing input data and checking for auto-params...
2023-03-01T19:08:33.773119Z  Found 2 auto-params that were set based on input data.
{
    "epochs": 600,
    "batch_size": 600
}
2023-03-01T19:08:33.776135Z  Using updated model configuration: 
{
    "schema_version": "1.0",
    "name": "tabular-actgan",
    "models": [
        {
            "actgan": {
                "privacy_filters": {
                    "outliers": null,
                    "similarity": null,
                    "max_iterations": 10
                },
                "data_source": [
                    "training_data.csv"
                ],
                "ref_data": {},
                "params": {
                    "embedding_dim": 128,
                    "generator_dim": [
                        1024,
                        1024
                    ],
                    "discriminator_dim": [
              

In [5]:
# Inspect the preview sample of synthetic records produced by the training run.
preview_path = target_dir / "data_preview.gz"
synthetic_df = pd.read_csv(preview_path, compression="gzip")
synthetic_df


Unnamed: 0,case_id,Hospital_code,Hospital_type_code,City_Code_Hospital,Hospital_region_code,Available Extra Rooms in Hospital,Department,Ward_Type,Ward_Facility_Code,Bed Grade,patientid,City_Code_Patient,Type of Admission,Severity of Illness,Visitors with Patient,Age,Admission_Deposit,Stay
0,3789,27,b,6,X,3,gynecology,R,F,3.0,117412,1.0,Emergency,Moderate,8,61-70,3618.0,31-40
1,801,31,g,9,Y,4,gynecology,Q,B,2.0,59835,15.0,Emergency,Minor,5,41-50,4115.0,31-40
2,4869,19,a,7,Y,3,gynecology,R,C,3.0,38543,8.0,Emergency,Moderate,2,31-40,4770.0,11-20
3,8980,6,b,4,X,2,anesthesia,R,F,3.0,120102,,Trauma,Moderate,3,31-40,4036.0,11-20
4,8046,1,c,3,Z,2,gynecology,P,F,2.0,8527,8.0,Emergency,Moderate,4,31-40,6261.0,81-90
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4995,7140,26,b,2,Y,2,anesthesia,S,D,3.0,101065,8.0,Emergency,Moderate,2,71-80,6039.0,41-50
4996,1838,25,b,2,Y,2,radiotherapy,Q,D,3.0,68951,8.0,Trauma,Extreme,2,11-20,5828.0,11-20
4997,9153,26,a,2,Y,4,gynecology,Q,D,3.0,15409,24.0,Emergency,Moderate,3,61-70,4936.0,31-40
4998,1657,17,c,3,Z,2,surgery,Q,A,2.0,2989,8.0,Emergency,Moderate,4,51-60,3676.0,11-20


In [9]:
# View the report comparing the statistical properties of the training and
# synthetic data.
# smart_open's `open` decompresses the gzipped report transparently; the
# `with` block guarantees the file handle is closed after reading.
with open(target_dir / "report.html.gz") as report_file:
    report_html = report_file.read()

# `isolated=True` asks the notebook frontend to render the report in its own
# frame so the report's styles do not bleed into the notebook.
HTML(data=report_html, metadata=dict(isolated=True))


0,1,2,3,4,5
Synthetic Data Use Cases,Excellent,Good,Moderate,Poor,Very Poor
Significant tuning required to improve model,,,,,
Improve your model using our tips and advice,,,,,
Demo environments or mock data,,,,,
Pre-production testing environments,,,,,
Balance or augment machine learning data sources,,,,,
Machine learning or statistical analysis,,,,,

0,1,2,3,4,5
Data Sharing Use Case,Excellent,Very Good,Good,Normal,Poor
"Internally, within the same team",,,,,
"Internally, across different teams",,,,,
"Externally, with trusted partners",,,,,
"Externally, public availability",,,,,

Unnamed: 0,Training Data,Synthetic Data
Row Count,250,250
Column Count,18,18
Training Lines Duplicated,--,0

Default Privacy Protections,Advanced Protections

Field,Unique,Missing,Ave. Length,Type,Distribution Stability
City_Code_Patient,20,6,3.08,Numeric,Moderate
patientid,226,0,5.19,Numeric,Good
City_Code_Hospital,11,0,1.11,Numeric,Good
Admission_Deposit,245,0,6.0,Numeric,Good
Bed Grade,4,0,3.0,Numeric,Good
Hospital_code,29,0,1.85,Numeric,Good
Visitors with Patient,13,0,1.02,Categorical,Good
case_id,250,0,3.9,Numeric,Good
Age,10,0,5.01,Categorical,Good
Available Extra Rooms in Hospital,8,0,1.0,Categorical,Good


In [7]:
# Generate more synthetic records by reusing the trained model archive from
# `target_dir` in another local Docker run; the output lands in `target_dir`.
num_records = 100_000
record_handler = model.create_record_handler_obj(params={"num_records": num_records})

run = submit_docker_local(
    record_handler,
    model_path=target_dir / "model.tar.gz",
    output_dir=target_dir,
)


INFO: Starting poller
INFO: Status is created. A job has been queued.


{
    "uid": "63ffa3e44cc0e7450bfa7ee9",
    "guid": "model_run_2MQRroh6knxHCoIDjcpWeiZoUcX",
    "model_name": null,
    "runner_mode": "manual",
    "user_id": "60ec86ce492fbf1c604a6ea5",
    "user_guid": "user_26U3XeNlVkqbJkvxJp0vcBs0LJQ",
    "billing_domain": "gretel.ai",
    "billing_domain_guid": "domain_28bzIokk1eQdWUYsovba0VN1gtY",
    "project_id": "63ff81bba507100935e704ae",
    "project_guid": "proj_2MQA8pmBQ4DC2wGNeM938kx2pap",
    "status_history": {
        "created": "2023-03-01T19:13:40.569000Z"
    },
    "last_modified": "2023-03-01T19:13:40.728000Z",
    "status": "created",
    "last_active_hb": null,
    "duration_minutes": null,
    "error_msg": null,
    "error_id": null,
    "traceback": null,
    "annotations": null,
    "container_image": "074762682575.dkr.ecr.us-east-2.amazonaws.com/models/actgan@sha256:1b8fd472f046a15773116ce0101016e295f6c6cdd4b6fc6042365c789bda3a0a",
    "container_image_version": "4d0fd64c",
    "model_id": "63ffa29ce59169ec864bad78",
   

INFO: Status is active. A worker has started!
2023-03-01T19:14:03.813133Z  Loading model to worker
2023-03-01T19:14:21.354256Z  Loading ACTGAN model...
2023-03-01T19:14:23.833611Z  Sampling 100000 records...
2023-03-01T19:14:30.412557Z  Preparing privacy filters
2023-03-01T19:14:30.413569Z  Loaded 0 privacy filters
2023-03-01T19:14:30.413950Z  Starting privacy filtering
2023-03-01T19:14:30.414349Z  Privacy filtering complete.


In [8]:
# Load the full set of records produced by the generation run above.
generated_path = target_dir / "data.gz"
synthetic_df_new = pd.read_csv(generated_path, compression="gzip")
synthetic_df_new


Unnamed: 0,case_id,Hospital_code,Hospital_type_code,City_Code_Hospital,Hospital_region_code,Available Extra Rooms in Hospital,Department,Ward_Type,Ward_Facility_Code,Bed Grade,patientid,City_Code_Patient,Type of Admission,Severity of Illness,Visitors with Patient,Age,Admission_Deposit,Stay
0,305,20,c,3,Z,3,gynecology,R,A,3.0,83024,8.0,Trauma,Moderate,4,41-50,5256.0,11-20
1,6502,20,c,3,Z,3,gynecology,R,A,3.0,41265,23.0,Trauma,Moderate,2,51-60,5549.0,91-100
2,9082,19,a,9,Y,4,surgery,R,B,2.0,81826,6.0,Trauma,Moderate,2,21-30,4060.0,51-60
3,2068,9,d,6,Z,3,gynecology,R,F,2.0,2692,1.0,Trauma,Moderate,5,41-50,5014.0,51-60
4,7024,25,b,7,Y,2,gynecology,Q,B,,23995,8.0,Emergency,Moderate,3,61-70,5006.0,31-40
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,7026,23,d,5,Z,3,gynecology,R,F,2.0,9466,8.0,Emergency,Moderate,2,71-80,4869.0,11-20
99996,7082,5,a,6,X,3,gynecology,R,F,3.0,123748,1.0,Trauma,Moderate,3,31-40,5707.0,11-20
99997,6030,10,b,2,Y,3,gynecology,Q,D,3.0,72569,8.0,Trauma,Minor,2,31-40,5985.0,11-20
99998,4178,23,a,6,X,2,gynecology,R,F,4.0,62315,1.0,Emergency,Extreme,4,71-80,3833.0,31-40
