# Create synthetic mouse phenome data

Create a synthetic version of the mouse phenome data from the original experiment; the training data is available after running `01_create_phenome_training_data.ipynb`. To run this notebook, you need an API key from the Gretel console at https://console.gretel.cloud.

In [2]:
%%capture
!python3 -m pip install -U gretel-client

In [1]:
# Specify your Gretel API key

from getpass import getpass
import pandas as pd
from gretel_client import configure_session, ClientConfig

# Show full cell contents when displaying DataFrames
pd.set_option('max_colwidth', None)

# Prompt for the API key without echoing it, then configure the Gretel client session
configure_session(ClientConfig(api_key=getpass(prompt="Enter Gretel API key: "),
                               endpoint="https://api.gretel.cloud"))

                            



## Configure model hyperparameters
Load the default configuration template, which works well for most datasets. Other templates are available at https://github.com/gretelai/gretel-blueprints/tree/main/config_templates/gretel/synthetics

In [2]:
import json
import yaml
# smart_open's open() reads URLs as well as local paths
from smart_open import open

TEMPLATE_URL = "https://raw.githubusercontent.com/gretelai/gretel-blueprints/main/config_templates/gretel/synthetics/default.yml"
with open(TEMPLATE_URL, 'r') as stream:
    config = yaml.safe_load(stream)

# Optimize parameters for this complex dataset
params = config['models'][0]['synthetics']['params']  # same dict, edited in place
params['epochs'] = 200          # upper bound; early_stopping is enabled in the template
params['vocab_size'] = 0        # 0 = character-level tokens instead of a learned subword vocabulary
params['rnn_units'] = 640
params['reset_states'] = False
params['learning_rate'] = 0.001
params['dropout_rate'] = 0.4312
params['gen_temp'] = 1.003      # sampling temperature used during generation
# Turn off the similarity privacy filter (the outlier filter stays at "medium")
config['models'][0]['synthetics']['privacy_filters']['similarity'] = None

print(json.dumps(config, indent=2))

{
  "schema_version": "1.0",
  "models": [
    {
      "synthetics": {
        "data_source": "__tmp__",
        "params": {
          "epochs": 200,
          "batch_size": 64,
          "vocab_size": 0,
          "reset_states": false,
          "learning_rate": 0.001,
          "rnn_units": 640,
          "dropout_rate": 0.4312,
          "overwrite": true,
          "early_stopping": true,
          "gen_temp": 1.003,
          "predict_batch_size": 64,
          "validation_split": false,
          "dp": false,
          "dp_noise_multiplier": 0.001,
          "dp_l2_norm_clip": 5.0,
          "dp_microbatches": 1
        },
        "validators": {
          "in_set_count": 10,
          "pattern_count": 10
        },
        "generate": {
          "num_records": 5000,
          "max_invalid": null
        },
        "privacy_filters": {
          "outliers": "medium",
          "similarity": null
        }
      }
    }
  ]
}
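Each override above repeats the `config['models'][0]['synthetics']['params']` lookup. A minimal sketch of a helper that applies several overrides at once, assuming the config keeps the nested layout printed above; `set_synthetics_params` is hypothetical (not part of gretel-client), and the stand-in values in the usage are illustrative, not the template's defaults:

```python
import copy


def set_synthetics_params(config: dict, model_index: int = 0, **params) -> dict:
    """Return a copy of `config` with the given synthetics params overridden.

    Assumes the layout shown above: config["models"][i]["synthetics"]["params"].
    """
    updated = copy.deepcopy(config)  # leave the caller's config untouched
    updated["models"][model_index]["synthetics"]["params"].update(params)
    return updated


# Usage with a minimal, hypothetical stand-in for the loaded template
base = {"models": [{"synthetics": {"params": {"epochs": 100, "vocab_size": 20000}}}]}
tuned = set_synthetics_params(base, epochs=200, vocab_size=0, rnn_units=640)
print(tuned["models"][0]["synthetics"]["params"])
```

Copying first means the downloaded template can be reused for other experiments; if that is not needed, calling `.update()` on the params dict in place works just as well.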


## Train the synthetic model
In this step, we submit one training job per phenome batch. Each job runs on a Gretel worker, either in the Gretel cloud or locally, and trains a synthetic model on that batch.

In [3]:
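# Get the location of the phenome training data

import os
import pathlib

# Assumes the notebook runs from the project's synthetics folder; dropping "/synthetics" gives the project root
base_path = pathlib.Path(os.getcwd().replace("/synthetics", ""))
data_path = base_path / 'mice_data_set' / 'data'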
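The `os.getcwd().replace("/synthetics", "")` call removes that substring anywhere in the path, so a directory such as `/data/synthetics_runs/proj` would become `/data_runs/proj`. A minimal sketch of a stricter alternative, assuming the notebook is launched directly from a folder named `synthetics`; `resolve_base_path` is a hypothetical helper, not part of the original project:

```python
import pathlib


def resolve_base_path(cwd: pathlib.Path) -> pathlib.Path:
    """Return the project root, stepping up one level when run from a `synthetics` folder.

    Only the last path component is inspected, so a "synthetics" substring
    elsewhere in the path is left alone.
    """
    return cwd.parent if cwd.name == "synthetics" else cwd


# Hypothetical usage mirroring the cell above
base_path = resolve_base_path(pathlib.Path.cwd())
data_path = base_path / "mice_data_set" / "data"
print(data_path)
```

Unlike the string replacement, this does not handle running from a subfolder of `synthetics`; that case would need an explicit search up the directory tree.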

In [4]:
# Define a function to submit a new model for a specific phenome batch dataset

import time
from gretel_client import create_project

def create_model(batch_num):
    # Give each project a unique, timestamped display name
    seconds = int(time.time())
    project = create_project(display_name=f"Training phenomes {seconds}")
    # Training data for this batch, e.g. pheno_batch0.csv
    trainpath = str(data_path / f"pheno_batch{batch_num}.csv")
    model = project.create_model_obj(model_config=config)
    model.data_source = trainpath
    # Upload the local CSV along with the job
    model.submit(upload_data_source=True)
    return model

In [5]:
# Submit all the phenome batches to train in parallel; poll for completion

from gretel_client.helpers import poll
from gretel_client import create_project
import time

def get_status(model):
    # Refresh the job and read its status; note this relies on private gretel-client attributes
    model._poll_job_endpoint()
    return model.__dict__['_data']['model']['status']

# Create a model for each batch
models = [create_model(i) for i in range(7)]

# Poll once a minute until no batch is active or pending; resubmit any batch that errors
training = True
while training:
    time.sleep(60)
    training = False
    print()
    for i, model in enumerate(models):
        status = get_status(model)
        print(f"Batch {i} has status: {status}")
        if status in ("active", "pending"):
            training = True
        elif status == "error":
            models[i] = create_model(i)
            training = True

# Now that models are complete, get each batch's Synthetic Quality Score (SQS)
print()
for batch, model in enumerate(models):
    if get_status(model) == "error":
        print(f"Batch {batch} ended with error")
        continue
    sqs = model.peek_report()['synthetic_data_quality_score']['score']
    # Map the score to Gretel's SQS bands
    if sqs >= 80:
        label = "Excellent"
    elif sqs >= 60:
        label = "Good"
    elif sqs >= 40:
        label = "Moderate"
    elif sqs >= 20:
        label = "Poor"
    else:
        label = "Very Poor"
    print(f"Batch {batch} completes with SQS: {label} ({sqs})")


Batch 0 has status: active
Batch 1 has status: active
Batch 2 has status: active
Batch 3 has status: active
Batch 4 has status: active
Batch 5 has status: pending
Batch 6 has status: pending

Batch 0 has status: active
Batch 1 has status: active
Batch 2 has status: active
Batch 3 has status: active
Batch 4 has status: active
Batch 5 has status: active
Batch 6 has status: active

Batch 0 has status: active
Batch 1 has status: active
Batch 2 has status: error
Batch 3 has status: active
Batch 4 has status: active
Batch 5 has status: active
Batch 6 has status: active

Batch 0 has status: completed
Batch 1 has status: active
Batch 2 has status: active
Batch 3 has status: active
Batch 4 has status: active
Batch 5 has status: active
Batch 6 has status: active

Batch 0 has status: completed
Batch 1 has status: completed
Batch 2 has status: active
Batch 3 has status: active
Batch 4 has status: error
Batch 5 has status: active
Batch 6 has status: active

Batch 0 has status: completed
Batch 1 ha
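The polling loop above mixes Gretel API calls with the retry logic. A minimal, Gretel-independent sketch of the same poll-and-resubmit pattern, in which `get_status` and `resubmit` are hypothetical callables you would back with the status lookup and `create_model(i)` used above. Unlike the original loop, it caps resubmissions per batch so a batch that keeps failing cannot keep the loop running forever:

```python
import time
from typing import Callable, Dict, List

ACTIVE_STATES = {"active", "pending"}


def poll_until_done(
    jobs: List[object],
    get_status: Callable[[object], str],
    resubmit: Callable[[int], object],
    interval_s: float = 60.0,
    max_resubmits: int = 3,
) -> Dict[int, str]:
    """Poll every job until none is active/pending; resubmit errored jobs.

    Replacement jobs are written back into `jobs` in place.
    Returns the final status of each job index.
    """
    resubmits = [0] * len(jobs)
    while True:
        still_running = False
        statuses = {}
        for i, job in enumerate(jobs):
            status = get_status(job)
            if status == "error" and resubmits[i] < max_resubmits:
                jobs[i] = resubmit(i)
                resubmits[i] += 1
                status = "pending"  # treat the fresh job as not yet finished
            statuses[i] = status
            if status in ACTIVE_STATES:
                still_running = True
        if not still_running:
            return statuses
        time.sleep(interval_s)


# Usage with scripted fake jobs: job "b" fails once and its replacement "b2" succeeds
scripts = {"a": ["active", "completed"], "b": ["error"], "b2": ["active", "completed"]}


def fake_status(job):
    seq = scripts[job]
    # Advance through the script, then keep reporting the last status
    return seq.pop(0) if len(seq) > 1 else seq[0]


def fake_resubmit(i):
    return "b2"


jobs = ["a", "b"]
final = poll_until_done(jobs, fake_status, fake_resubmit, interval_s=0)
print(final, jobs)
```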

In [6]:
# Read the original phenome set

filename = "phenome_alldata.csv"
filepath = data_path / filename
phenome_orig = pd.read_csv(filepath)

In [7]:
# Merge the synthetic batches into one dataframe
# First gather the synthetic data for each batch

synth_batches = []
for i in range(7):
    model = models[i]
    synth = pd.read_csv(model.get_artifact_link("data_preview"), compression='gzip')
    synth_batches.append(synth)


In [8]:
# Rows are paired on shared key columns. Keys can repeat, so each row also gets a
# helper column 'g' numbering its occurrence within its key group (0, 1, ...);
# merging on keys + 'g' pairs rows one-to-one instead of forming every
# combination of rows that share a key.
def merge_on_keys(left, right, keys):
    left = left.assign(g=left.groupby(keys).cumcount())
    right = right.assign(g=right.groupby(keys).cumcount())
    return pd.merge(left, right, on=keys + ['g'], how='left').drop('g', axis=1)

# Merge batches 0 and 1 on the common field sacweight
synth_allbatches = merge_on_keys(synth_batches[0], synth_batches[1], ['sacweight'])

# Merge in batch 2 on the common fields SW16, SW20, SW17
synth_allbatches = merge_on_keys(synth_allbatches, synth_batches[2], ['SW16', 'SW20', 'SW17'])

# Concatenate batch 3 column-wise, aligned on row index (no key merge)
synth_allbatches = pd.concat([synth_allbatches, synth_batches[3]], axis=1)

# Merge in batches 4 and 5 on the common methcage fields
methcage_keys = ['methcage12', 'methcage9', 'methcage11', 'methcage8', 'methcage7', 'methcage10']
synth_allbatches = merge_on_keys(synth_allbatches, synth_batches[4], methcage_keys)
synth_allbatches = merge_on_keys(synth_allbatches, synth_batches[5], methcage_keys)

# Concatenate batch 6 column-wise
synth_allbatches = pd.concat([synth_allbatches, synth_batches[6]], axis=1)
synth_allbatches



| | TA | SW16 | tibia | EDL | plantaris | gastroc | SW6 | sacweight | BMD | abBMD | ... | pp6PPIavg | pp12PPIavg | PPIavg | startle | p120b4 | PPIbox1 | PPIbox2 | PPIbox3 | PPIbox4 | p120b1 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 54.0 | 0.0 | 18.39 | 12.3 | 19.4 | 140.2 | 1.0 | 32.2 | 1.90 | 0.0 | ... | 0.53 | 0.86 | 0.48 | 55.50 | 45.50 | 1.0 | 0.0 | 0.0 | 0.0 | 76.00 |
| 1 | 56.2 | 0.0 | 18.30 | 12.4 | 17.1 | 141.7 | 0.0 | 39.9 | 1.86 | 0.0 | ... | 0.49 | 0.70 | 0.50 | 44.08 | 53.50 | 1.0 | 0.0 | 0.0 | 0.0 | 58.00 |
| 2 | 54.8 | 0.0 | 18.24 | 13.9 | 17.6 | 139.4 | 0.0 | 36.8 | 1.89 | 0.0 | ... | 0.49 | 0.70 | 0.50 | 11.17 | 11.17 | 0.0 | 0.0 | 0.0 | 0.0 | 26.00 |
| 3 | 58.5 | 0.0 | 18.69 | 12.3 | 19.5 | 152.5 | 0.0 | 43.1 | 1.92 | 0.0 | ... | -2.00 | -2.00 | -2.00 | 14.00 | 9.17 | 0.0 | 0.0 | 1.0 | 0.0 | 34.50 |
| 4 | 64.2 | 0.0 | 19.01 | 13.1 | 18.3 | 162.1 | 0.0 | 53.0 | 1.89 | 0.0 | ... | 0.56 | 0.70 | 0.56 | 25.58 | 15.50 | 0.0 | 0.0 | 0.0 | 0.0 | 51.50 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 4995 | 58.8 | 0.0 | 18.58 | 11.9 | 19.2 | 140.6 | 0.0 | 37.1 | 1.89 | 0.0 | ... | -0.01 | 0.28 | 0.01 | 8.00 | 49.67 | 1.0 | 0.0 | 0.0 | 0.0 | 15.50 |
| 4996 | 56.3 | 0.0 | 18.14 | 11.4 | 14.3 | 151.0 | 0.0 | 36.0 | 1.87 | 0.0 | ... | -0.01 | 0.37 | 0.11 | 15.67 | 7.00 | 0.0 | 0.0 | 1.0 | 0.0 | 44.67 |
| 4997 | 62.6 | 0.0 | 18.91 | 10.8 | 17.9 | 148.9 | 0.0 | 41.6 | 1.92 | 0.0 | ... | 0.29 | 0.58 | 0.24 | 14.75 | 18.33 | 1.0 | 0.0 | 0.0 | 0.0 | 32.00 |
| 4998 | 63.5 | 0.0 | 18.48 | 14.1 | 17.2 | 148.6 | 0.0 | 43.2 | 1.93 | 0.0 | ... | 0.21 | 0.76 | 0.26 | 100.67 | 88.50 | 1.0 | 0.0 | 0.0 | 0.0 | 62.83 |
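The merges above pair rows on shared key columns whose values can repeat. A toy illustration of why the `cumcount` helper column `g` is needed, using made-up values (not from the dataset) and borrowing the column names `sacweight`, `TA`, and `BMD` only for familiarity:

```python
import pandas as pd

# Two toy "batches" that share a key column with a repeated value (30.0)
left = pd.DataFrame({"sacweight": [30.0, 30.0, 41.5], "TA": [54.0, 56.2, 58.5]})
right = pd.DataFrame({"sacweight": [30.0, 30.0, 41.5], "BMD": [1.90, 1.86, 1.92]})

# Naive merge: each 30.0 on the left matches both 30.0 rows on the right,
# giving 2 + 2 + 1 = 5 rows instead of 3
naive = pd.merge(left, right, on="sacweight", how="left")

# cumcount numbers repeated keys 0, 1, ... within each group, so (key, g) is unique
left["g"] = left.groupby("sacweight").cumcount()
right["g"] = right.groupby("sacweight").cumcount()
paired = pd.merge(left, right, on=["sacweight", "g"], how="left").drop("g", axis=1)

print(len(naive), len(paired))
print(paired)
```

Within a repeated key, rows are paired by order of appearance, which is arbitrary beyond the key itself. With `how='left'`, a left row that has no partner with the same key and rank gets `NaN` in the right batch's columns.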


In [9]:
# Add back the "id" and "discard" fields, then save the complete synthetic data

synth_allbatches["id"] = range(len(synth_allbatches.index))  # sequential ids 0..n-1
synth_allbatches["discard"] = "no"                           # same value on every row
filepath = data_path / 'phenome_alldata_synth.csv'
synth_allbatches.to_csv(filepath, index=False, header=True)
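Before saving, it can be worth comparing the merged synthetic columns with those of the original phenome set (`phenome_orig`, read earlier), since a column-wise `pd.concat` keeps duplicate column names rather than combining them. A minimal sketch with toy frames; `compare_columns` is a hypothetical helper, and the extra column `g` in the toy data just stands in for a leftover helper column:

```python
import pandas as pd


def compare_columns(original: pd.DataFrame, synthetic: pd.DataFrame) -> dict:
    """List columns missing from, extra in, or duplicated in the synthetic frame."""
    orig_cols, synth_cols = set(original.columns), set(synthetic.columns)
    duplicated = synthetic.columns[synthetic.columns.duplicated()]
    return {
        "missing_in_synthetic": sorted(orig_cols - synth_cols),
        "extra_in_synthetic": sorted(synth_cols - orig_cols),
        "duplicated_in_synthetic": sorted(set(duplicated)),
    }


# Toy frames; in the notebook this would be compare_columns(phenome_orig, synth_allbatches)
orig = pd.DataFrame([[0, 54.0, 1.90, "no"]], columns=["id", "TA", "BMD", "discard"])
synth = pd.DataFrame([[54.0, 1.90, 0, "no", 0, 54.1]],
                     columns=["TA", "BMD", "id", "discard", "g", "TA"])
result = compare_columns(orig, synth)
print(result)
```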

In [11]:
# Save the abBMD and SW16 values to be used as seeds when generating the synthetic genome data
# First count the records in the phenome training data (batch 0)

batchfile = "pheno_batch0.csv"
trainpath = str(data_path / batchfile)
batch_df = pd.read_csv(trainpath)
num_records = len(batch_df.index)

# Keep only abBMD and SW16, and sample as many synthetic rows as there are training records
seeds_df = synth_batches[0].filter(['abBMD', 'SW16']).sample(n=num_records, random_state=1)
seedfile = data_path / 'phenome_abBMD_seeds.csv'
seeds_df.to_csv(seedfile, index=False, header=True)

In [17]:
# Save all synthetic abBMD and SW16 values (no sampling) as seeds for the synthetic genome data.
# This writes to the same file, so running it overwrites the sampled seeds saved by the previous cell.

seeds_df = synth_batches[0].filter(['abBMD', 'SW16'])
seedfile = data_path / 'phenome_abBMD_seeds.csv'
seeds_df.to_csv(seedfile, index=False, header=True)

In [12]:
# Display the report comparing the statistical properties of the training and synthetic data
# Batch 0 is the synthetic batch that includes abBMD

from smart_open import open
from IPython.display import display, HTML

# Change batch_num to any value between 0 and 6 to view the report for another batch
batch_num = 0
report_html = open(models[batch_num].get_artifact_link("report")).read()
display(HTML(data=report_html, metadata=dict(isolated=True)))


| Synthetic Data Use Cases | Excellent | Good | Moderate | Poor | Very Poor |
|---|---|---|---|---|---|
| Significant tuning required to improve model | | | | | |
| Improve your model using our tips and advice | | | | | |
| Demo environments or mock data | | | | | |
| Pre-production testing environments | | | | | |
| Balance or augment machine learning data sources | | | | | |
| Machine learning or statistical analysis | | | | | |

| Data Sharing Use Case | Excellent | Very Good | Good | Normal |
|---|---|---|---|---|
| Internally, within the same team | | | | |
| Internally, across different teams | | | | |
| Externally, with trusted partners | | | | |
| Externally, public availability | | | | |

| | Training Data | Synthetic Data |
|---|---|---|
| Row Count | 909 | 5000 |
| Column Count | 10 | 10 |
| Training Lines Duplicated | -- | 109 |

| Default Privacy Protections | Advanced Protections |
|---|---|

| Field | Unique | Missing | Ave. Length | Type | Distribution Stability |
|---|---|---|---|---|---|
| tibia | 206 | 0 | 4.9 | Numeric | Good |
| BMD | 42 | 0 | 3.9 | Numeric | Good |
| gastroc | 492 | 0 | 4.98 | Numeric | Good |
| EDL | 87 | 0 | 3.97 | Numeric | Excellent |
| plantaris | 135 | 0 | 4.0 | Numeric | Excellent |
| abBMD | 2 | 0 | 3.0 | Categorical | Excellent |
| TA | 261 | 0 | 4.0 | Numeric | Excellent |
| sacweight | 215 | 0 | 4.0 | Numeric | Excellent |
| SW16 | 2 | 0 | 3.0 | Categorical | Excellent |
| SW6 | 2 | 0 | 3.0 | Categorical | Excellent |
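The field table in the report lists unique counts, missing values, and average lengths per field. A rough pandas sketch for computing similar per-field statistics yourself, for example to compare `batch_df` with `synth_batches[0]`; `field_summary` is hypothetical, and Gretel's own definitions (in particular of "Ave. Length") may differ:

```python
import pandas as pd


def field_summary(df: pd.DataFrame) -> pd.DataFrame:
    """Per-field unique count, missing count and mean string length.

    A rough approximation of the report's field table; Gretel's exact
    definitions may differ.
    """
    return pd.DataFrame({
        # Distinct non-missing values per column
        "Unique": df.nunique(),
        # Number of missing cells per column
        "Missing": df.isna().sum(),
        # Mean length of the non-missing values rendered as strings
        "Ave. Length": df.apply(lambda col: col.dropna().astype(str).str.len().mean()).round(2),
    })


# Toy data; in the notebook, compare field_summary(batch_df) with field_summary(synth_batches[0])
demo = pd.DataFrame({"abBMD": [0.0, 1.0, 0.0], "TA": [54.0, 56.2, None]})
summary = field_summary(demo)
print(summary)
```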
