# Synthetic Data Generator

Following the documentation for [Gretel.ai's Synthetics library](https://github.com/gretelai/gretel-synthetics) (Apache License 2.0), we adapt it to create synthetic data for our needs. The primary goal is to generate synthetic *health* data, but the same approach can generate any dataset once it is given a sample to train on.

In [31]:
import csv
from pathlib import Path

import pandas as pd
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.generate import generate_text
from gretel_synthetics.tokenizers import CharTokenizerTrainer
from gretel_synthetics.train import train

In [10]:
## Convert XPT file to CSV for Gretel
# mcq = pd.read_sas('data/P_MCQ.XPT')
# mcq.to_csv('data/P_MCQ.csv')

In [37]:
# Load the CSV converted from P_MCQ.XPT (see the commented-out conversion above).
# NOTE(review): a CSV written by the default `to_csv` call keeps the DataFrame
# index as an extra leading 'Unnamed: 0' column -- confirm whether that is wanted.
input_data = pd.read_csv('data/P_MCQ.csv')

# Map questionnaire variable codes to human-readable labels.
# Labels must be unique: renaming two codes to the same label creates duplicate
# column names, and input_data[label] then returns a DataFrame instead of a Series.
# Codes that share a question text are therefore suffixed with their code.
medcols = {
    'SEQN' : 'Respondent sequence number',
    'MCQ010' : 'Ever been told you have asthma',
    'MCQ025' : 'Age when first had asthma',
    'MCQ035' : 'Still have asthma',
    'MCQ040' : 'Had asthma attack in past year',
    'MCQ050' : 'Emergency care visit for asthma/past yr',
    'AGQ030' : 'Did SP have episode of hay fever/past yr',
    'MCQ053' : 'Taking treatment for anemia/past 3 mos',
    'MCQ080' : 'Doctor ever said you were overweight',
    'MCQ092' : 'Ever receive blood transfusion',
    'MCD093' : 'Year receive blood transfusion',
    'MCQ145' : 'CHECK ITEM (MCQ145)',
    'MCQ149' : 'Menstrual periods started yet?',
    'MCQ151' : 'Age in years at first menstrual period',
    'RHD018' : 'Estimated age in months at menarche',
    'MCQ160a' : 'Doctor ever said you had arthritis',
    'MCQ195' : 'Which type of arthritis was it?',
    'MCQ160B' : 'Ever told had congestive heart failure',
    'MCD180B' : 'Age when told you had heart failure',
    'MCQ160C' : 'Ever told you had coronary heart disease',
    'MCD180C' : 'Age when told had coronary heart disease',
    'MCQ160D' : 'Ever told you had angina/angina pectoris',
    'MCD180D' : 'Age when told you had angina pectoris',
    'MCQ160E' : 'Ever told you had heart attack',
    'MCD180E' : 'Age when told you had heart attack',
    'MCQ160F' : 'Ever told you had a stroke',
    'MCD180F' : 'Age when told you had a stroke',
    'MCQ160M' : 'Ever told you had thyroid problem',
    'MCQ170M' : 'Do you still have thyroid problem',
    'MCD180M' : 'Age when told you had thyroid problem',
    'MCQ160P' : 'Ever told you had COPD, emphysema, ChB',
    'MCQ160L' : 'Ever told you had any liver condition (MCQ160L)',
    'MCQ170L' : 'Do you still have a liver condition',
    'MCD180L' : 'Age when told you had a liver condition',
    'MCQ500' : 'Ever told you had any liver condition (MCQ500)',
    'MCQ510A' : 'Liver condition: Fatty liver',
    'MCQ510B' : 'Liver condition: Liver fibrosis',
    'MCQ510C' : 'Liver condition: Liver cirrhosis',
    'MCQ510D' : 'Liver condition: Viral hepatitis',
    'MCQ510E' : 'Liver condition: Autoimmune hepatitis',
    'MCQ510F' : 'Liver condition: Other liver disease',
    'MCQ515' : 'CHECK ITEM (MCQ515)',
    'MCQ520' : 'Abdominal pain during past 12 months?',
    'MCQ530' : 'Where was the most uncomfortable pain',
    'MCQ540' : 'Ever seen a DR about this pain',
    'MCQ550' : 'Has DR ever said you have gallstones',
    'MCQ560' : 'Ever had gallbladder surgery?',
    'MCQ570' : 'Age when 1st had gallbladder surgery?',
    'MCQ220' : 'Ever told you had cancer or malignancy',
    'MCQ230A' : '1st cancer - what kind was it?',
    'MCQ230B' : '2nd cancer - what kind was it?',
    'MCQ230C' : '3rd cancer - what kind was it?',
    'MCQ230D' : 'More than 3 kinds of cancer',
    'MCQ300B' : 'Close relative had asthma?',
    'MCQ300C' : 'Close relative had diabetes?',
    'MCQ300A' : 'Close relative had heart attack?',
    'MCQ366A' : 'Doctor told you to control/lose weight',
    'MCQ366B' : 'Doctor told you to exercise',
    'MCQ366C' : 'Doctor told you to reduce salt in diet',
    'MCQ366D' : 'Doctor told you to reduce fat/calories',
    'MCQ371A' : 'Are you now controlling or losing weight',
    'MCQ371B' : 'Are you now increasing exercise',
    'MCQ371C' : 'Are you now reducing salt in diet',
    'MCQ371D' : 'Are you now reducing fat in diet',
    'OSQ230' : 'Any metal objects inside your body?',
}

# Fail fast if a future edit reintroduces a duplicate label.
if len(set(medcols.values())) != len(medcols):
    raise ValueError('medcols labels must be unique')

# rename() silently skips codes that are not columns of the CSV (errors='ignore'
# is pandas' default), so a mistyped code just leaves that column unrenamed.
input_data = input_data.rename(medcols, axis=1)

In [18]:
# Configure Gretel's TensorFlow-backed synthetic model.
# Arguments prefixed with 'dp' control training with differential privacy.
# Differential privacy here does not noise the data itself: during training,
# gradients are clipped (dp_l2_norm_clip) and noised (dp_noise_multiplier).
# Those arguments are commented out because dp=False; set dp=True and
# uncomment them to train with differential privacy.

config = TensorFlowConfig(
    epochs=5,
    # Number of synthetic lines to generate; also used as the cap in the
    # generation loop further down.
    gen_lines=1000,
    # Maximum number of input lines read for training. This must be an int line
    # count (was the float 1e5). It exceeds the ~15k lines in data/P_MCQ.csv,
    # so the whole file is used.
    max_lines=100_000,
    dp=False,
    predict_batch_size=1,
    rnn_units=256,
    batch_size=16,
    learning_rate=0.0015,
    # dp_noise_multiplier=0.2,
    # dp_l2_norm_clip=1.0,
    dropout_rate=0.5,
    # dp_microbatches=1,
    reset_states=False,
    overwrite=True,
    checkpoint_dir=(Path.cwd() / 'checkpoints').as_posix(),
    # Gretel reads training lines directly from this file, not from the renamed
    # input_data DataFrame.
    input_data_path='data/P_MCQ.csv'
)

In [19]:
# Initialize a character tokenizer trainer. Its vocabulary is built from
# config.input_data_path when train() runs below.
tokenizer = CharTokenizerTrainer(config=config)

# Train the model; the trained model is saved under config.checkpoint_dir.
train(config, tokenizer)

2022-05-30 11:43:32,415 : MainThread : INFO : Loading input data from data/P_MCQ.csv
2022-05-30 11:43:32,494 : MainThread : INFO : Tokenizing input data
100%|██████████| 14987/14987 [00:00<00:00, 66212.78it/s]
2022-05-30 11:43:32,723 : MainThread : INFO : Shuffling input data
2022-05-30 11:43:36,355 : MainThread : INFO : Creating validation dataset
2022-05-30 11:43:36,379 : MainThread : INFO : Creating training dataset
2022-05-30 11:43:39,478 : MainThread : INFO : Initializing synthetic model
2022-05-30 11:43:39,835 : MainThread : INFO : Using keras.optimizers.RMSprop optimizer


Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_3 (Embedding)     (16, None, 256)           7424      
                                                                 
 dropout_9 (Dropout)         (16, None, 256)           0         
                                                                 
 lstm_6 (LSTM)               (16, None, 256)           525312    
                                                                 
 dropout_10 (Dropout)        (16, None, 256)           0         
                                                                 
 lstm_7 (LSTM)               (16, None, 256)           525312    
                                                                 
 dropout_11 (Dropout)        (16, None, 256)           0         
                                                                 
 dense_3 (Dense)             (16, None, 29)           

2022-05-30 12:04:50,787 : MainThread : INFO : Saving model history to model_history.csv
2022-05-30 12:04:50,789 : MainThread : INFO : Saving model to /Users/jordanswanson/Documents/GitHub/lesson-plans/notebooks/health-data-privacy/checkpoints/synthetic


In [33]:
# Build a validator
def validate_record(line, expected_fields=None):
    """Raise if a generated line does not look like a row of the training CSV.

    Used as ``line_validator`` for ``generate_text``. Raising is how this
    function reports an invalid line; returning normally means the line is
    accepted.

    Args:
        line: One generated line of comma-separated text.
        expected_fields: Number of fields a valid record must have. Defaults to
            ``len(input_data.columns)``, the width of the CSV loaded earlier in
            this notebook. Renaming columns does not change that count.

    Raises:
        ValueError: If the field count differs from ``expected_fields`` or a
            non-empty field is not numeric.
    """
    if expected_fields is None:
        expected_fields = len(input_data.columns)

    # csv.reader honours quoting, unlike a bare str.split(",").
    fields = next(csv.reader([line]), [])
    if len(fields) != expected_fields:
        raise ValueError('record not valid')

    # NOTE(review): assumes every column is numeric, as for data read from the
    # XPT file via pd.read_sas. Empty fields are missing values, since to_csv
    # writes NaN as an empty string. Revisit if text columns are added.
    for field in fields:
        if field.strip():
            float(field)

In [None]:
# Print the CSV header, then up to config.gen_lines valid synthetic lines.
counter = 0
# to_csv quotes labels that contain commas (e.g. 'Ever told you had COPD,
# emphysema, ChB'), so the header is a well-formed CSV row rather than the
# repr of a pandas Index.
print(input_data.head(0).to_csv(index=False), end='')
for line in generate_text(config,
                          line_validator=validate_record,
                          max_invalid=100_000):
    if line.valid:
        print(line.text)
        counter += 1
        # Stop once exactly gen_lines valid records have been printed.
        if counter >= config.gen_lines:
            break