/
conditional-generation.py
40 lines (33 loc) · 1.18 KB
/
conditional-generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import pandas as pd
from gretel_client import configure_session
from gretel_trainer import Trainer
from gretel_trainer.models import GretelACTGAN, GretelLSTM
# URL of the CSV dataset handed to the trainer as the training source.
DATASET_PATH = "https://gretel-public-website.s3.amazonaws.com/datasets/mitre-synthea-health.csv"  # noqa: E501
# Both model configurations are instantiated; the trailing index picks the one
# actually used (0 -> GretelLSTM, 1 -> GretelACTGAN).
MODEL_TYPE = (GretelLSTM(), GretelACTGAN())[1]
# Seed records to autocomplete values for: four identical
# black/african/F rows followed by five identical asian/chinese/F rows.
seed_df = pd.DataFrame(
    data=[["black", "african", "F"]] * 4 + [["asian", "chinese", "F"]] * 5,
    columns=["RACE", "ETHNICITY", "GENDER"],
)
# Configure Gretel credentials.
# NOTE(review): api_key="prompt" presumably asks for the key interactively and
# cache="yes" presumably stores it for later runs — confirm in gretel_client docs.
configure_session(api_key="prompt", cache="yes", validate=True)

# Train a model on the dataset using every seed_df column as a seed field,
# then conditionally generate data from the seed rows and print the result.
model = Trainer(model_type=MODEL_TYPE)
seed_fields = seed_df.columns.tolist()
model.train(DATASET_PATH, seed_fields=seed_fields)
generated = model.generate(seed_df=seed_df)
print(generated)
# Load an existing model and conditionally generate data
# model = Trainer.load()
# print(model.generate(seed_df=seed_df))